benchmarks/kernels/benchmark_paged_attention.py: 11 changes (6 additions, 5 deletions)
@@ -9,9 +9,8 @@
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser,
                         create_kv_caches_with_random)
 
-NUM_BLOCKS = 1024 * 1024
-PARTITION_SIZE = 512
-
+NUM_BLOCKS = 256 * 1024
+PARTITION_SIZE = 256
 
 @torch.inference_mode()
 def main(
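
For scale: the KV cache this benchmark allocates grows linearly with NUM_BLOCKS. A back-of-envelope sketch, where every parameter below is an assumption chosen for illustration rather than a value taken from this file:

```python
# All parameters here are illustrative assumptions, not values from the diff.
NUM_BLOCKS = 256 * 1024
block_size = 16        # tokens per cache block (assumed)
num_kv_heads = 8       # (assumed)
head_size = 128        # (assumed)
elem_bytes = 2         # fp16 elements (assumed)

# One K cache plus one V cache, hence the leading factor of 2.
cache_bytes = 2 * NUM_BLOCKS * block_size * num_kv_heads * head_size * elem_bytes
print(f"{cache_bytes / 2**30:.1f} GiB")  # 16.0 GiB under these assumptions
```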
@@ -101,7 +100,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
         start_time = time.perf_counter()
 
         # Using default kv_scale
-        k_scale = v_scale = 1.0
+        k_scale = v_scale = 0.1
 
         for _ in range(num_iters):
            if version == "v1":
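
For context on what k_scale and v_scale do: when the KV cache is stored in fp8, the kernel applies a per-tensor scale as it reads cached values back. A minimal Python sketch of that idea, using a hypothetical dequantize_kv helper rather than vLLM's actual kernel code, and assuming a PyTorch build that ships torch.float8_e4m3fn:

```python
import torch

def dequantize_kv(cache_fp8: torch.Tensor, scale: float) -> torch.Tensor:
    # Hypothetical helper for illustration: upcast the quantized entries,
    # then apply the calibration scale. vLLM does this inside the kernel.
    return cache_fp8.to(torch.float16) * scale

x = torch.randn(4, dtype=torch.float16)
scale = 0.1
cache = (x / scale).to(torch.float8_e4m3fn)   # quantize with 1/scale
print(dequantize_kv(cache, scale))            # approximately x
```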
@@ -161,6 +160,8 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
                     kv_cache_dtype,
                     k_scale,
                     v_scale,
+                    None,
+                    PARTITION_SIZE
                 )
             else:
                 raise ValueError(f"Invalid version: {version}")
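
The v2 kernel splits each sequence into fixed-size partitions, reduces within each partition, and then combines the partial results; PARTITION_SIZE sets that split, which is why it now flows into the kernel call. A small standalone sketch of the partition-count arithmetic (ceiling division, matching how this script sizes its v2 scratch buffers):

```python
PARTITION_SIZE = 256

def num_partitions(seq_len: int, partition_size: int = PARTITION_SIZE) -> int:
    # Ceiling division: each partition covers up to partition_size tokens.
    return (seq_len + partition_size - 1) // partition_size

for seq_len in (256, 1024, 4096):
    print(seq_len, "->", num_partitions(seq_len))  # 1, 4, 16
```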
@@ -180,7 +181,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
     if do_profile:
         latency = run_benchmark(num_iters=1, profile=True)
     else:
-        latency = run_benchmark(num_iters=1000, profile=False)
+        latency = run_benchmark(num_iters=10000, profile=False)
     print(f"Kernel running time: {latency * 1000000:.3f} us")
 
 
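Finally, the measurement pattern around run_benchmark, reduced to a standalone sketch. This is assumed from the perf_counter lines visible above, and time_kernel is a hypothetical name, not a function in this file:

```python
import time
import torch

def time_kernel(fn, num_iters: int = 10000) -> float:
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(num_iters):
        fn()
    torch.cuda.synchronize()
    # Per-iteration latency; more iterations mainly tightens the estimate.
    return (time.perf_counter() - start) / num_iters

# latency = time_kernel(lambda: ops.paged_attention_v2(...))
# print(f"Kernel running time: {latency * 1000000:.3f} us")
```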