
Commit 2f7dbc9

bwasti and yewentao256 authored
Add batch invariant kernel override for FlashInfer backend [2/n] (vllm-project#25769)
Signed-off-by: Bram Wasti <[email protected]>
Signed-off-by: Bram Wasti <[email protected]>
Co-authored-by: Wentao Ye <[email protected]>
1 parent ea25a76 commit 2f7dbc9

3 files changed: 84 additions, 29 deletions

tests/v1/generation/test_batch_invariance.py

Lines changed: 40 additions & 23 deletions
@@ -76,18 +76,21 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
     seed.
     - Keep max_tokens and max_model_len bounded for speed and memory use.
     """
-    random.seed(12345)
+    seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
+    random.seed(seed)
 
     # Allow overrides from environment (useful for CI tuning)
     # "facebook/opt-125m" is too small, doesn't reliably test determinism
     model = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
     num_trials = int(os.getenv("VLLM_NEEDLE_TRIALS", "5"))
-    batch_size = int(os.getenv("VLLM_NEEDLE_BATCH_SIZE", "64"))
-    assert batch_size >= 2, "Batch size should be >= 2 to mix needle."
+    max_batch_size = int(os.getenv("VLLM_NEEDLE_BATCH_SIZE", "128"))
+    min_random_prompt = int(os.getenv("VLLM_MIN_PROMPT", "1024"))
+    max_random_prompt = int(os.getenv("VLLM_MAX_PROMPT", "2048"))
+    assert max_batch_size >= 2, "Batch size should be >= 2 to mix needle."
 
     # Keep GPU memory usage low to avoid startup allocation failures.
-    gpu_mem_util = float(os.getenv("VLLM_GPU_MEMORY_UTILIZATION", "0.3"))
-    max_model_len = int(os.getenv("VLLM_MAX_MODEL_LEN", "4096"))
+    gpu_mem_util = float(os.getenv("VLLM_GPU_MEMORY_UTILIZATION", "0.4"))
+    max_model_len = int(os.getenv("VLLM_MAX_MODEL_LEN", "5120"))
     swap_space_gb = int(os.getenv("VLLM_SWAP_SPACE_GB", "4"))
 
     # Sampling parameters: longer outputs with a more random-sounding
@@ -111,7 +114,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
         # Engine with bs=1 behavior
         llm_bs1 = LLM_with_max_seqs(
             model=model,
-            max_num_seqs=1,
+            max_num_seqs=128,
             gpu_memory_utilization=gpu_mem_util,
             max_model_len=max_model_len,
             swap_space=swap_space_gb,
@@ -126,7 +129,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
         # Engine with larger batch limit (e.g., 64)
         llm_bsN = LLM_with_max_seqs(
             model=model,
-            max_num_seqs=batch_size,
+            max_num_seqs=128,
             gpu_memory_utilization=gpu_mem_util,
             max_model_len=max_model_len,
             swap_space=swap_space_gb,
@@ -135,15 +138,17 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
         mismatches = 0
 
         for trial in range(num_trials):
-            # Create a batch of size `batch_size` and insert the needle at
+            # Create a batch of size `max_batch_size` and insert the needle at
             # a random index
             prompts: list[str] = []
+            batch_size = random.randint(max_batch_size // 2, max_batch_size)
            needle_pos = random.randint(0, batch_size - 1)
             for i in range(batch_size):
                 if i == needle_pos:
                     prompts.append(needle_prompt)
                 else:
-                    prompts.append(_random_prompt())
+                    prompts.append(
+                        _random_prompt(min_random_prompt, max_random_prompt))
 
             # Generate with the larger-batch engine
             outputs = llm_bsN.generate(prompts, sampling)
@@ -154,17 +159,19 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
             text = needle_output.outputs[0].text
 
             if text != baseline_text:
+                print(
+                    f"{text}\n\n== Not the same as ==\n\n{baseline_text}\n\n")
                 mismatches += 1
 
         passes = num_trials - mismatches
         # Dump how many passed vs failed
         print(f"[determinism] total={num_trials}, passed={passes}, "
-              f"failed={mismatches}, batch_size={batch_size}")
+              f"failed={mismatches}, max_batch_size={max_batch_size}")
 
         if mismatches > 0:
             pytest.fail(
                 f"Nondeterministic outputs detected: {mismatches} failed out "
-                f"of {num_trials} trials (batch_size={batch_size}).")
+                f"of {num_trials} trials (max_batch_size={max_batch_size}).")
 
     finally:
         # Ensure engines are shutdown to free GPU/VRAM across test sessions
@@ -196,9 +203,14 @@ def _extract_step_logprobs(request_output):
     not torch.cuda.is_available(),
     reason="Requires CUDA to match production inference path.",
 )
-def test_logprobs_bitwise_batch_invariance_bs1_vs_bs2():
+@pytest.mark.parametrize("backend", ["FLEX_ATTENTION", "FLASHINFER"])
+def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend):
 
-    #model_name = os.getenv("VLLM_TEST_MODEL", "facebook/opt-125m")
+    backend = os.getenv("VLLM_ATTENTION_BACKEND", backend)
+    os.environ["VLLM_ATTENTION_BACKEND"] = backend
+
+    seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
+    random.seed(seed)
     model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
     tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
 
@@ -212,10 +224,15 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bs2():
     prompts = [
         "The capital of France is",
         "The capital of Germany is",
+        _random_prompt(10, 1024),
+        _random_prompt(10, 1024),
+        _random_prompt(10, 1024),
+        _random_prompt(10, 1024),
+        _random_prompt(10, 1024),
     ]
 
     sp = SamplingParams(
-        temperature=0.0,
+        temperature=0.6,
         top_p=1.0,
         max_tokens=8,
         # Seed shouldn't matter at temperature=0, but keeping it stable anyway.
@@ -234,25 +251,25 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bs2():
                         "enable logprobs return to run this test.")
         bs1_logprobs_per_prompt.append(step_logprobs)
 
-    # BS=2: run prompts in a batch and collect logprobs per step for each
+    # BS=N: run prompts in a batch and collect logprobs per step for each
     # prompt.
     outs_batched = llm.generate(prompts, sp, use_tqdm=False)
     assert len(outs_batched) == len(prompts)
-    bs2_logprobs_per_prompt = []
+    bsN_logprobs_per_prompt = []
     for o in outs_batched:
         step_logprobs = _extract_step_logprobs(o)
         if step_logprobs is None:
             pytest.skip("Logits are not available on RequestOutput; "
                         "enable logprobs return to run this test.")
-        bs2_logprobs_per_prompt.append(step_logprobs)
+        bsN_logprobs_per_prompt.append(step_logprobs)
 
-    # Compare step-by-step logprobs for each prompt between BS=1 and BS=2 runs.
-    for i, (logprobs_bs1, logprobs_bs2) in enumerate(
-            zip(bs1_logprobs_per_prompt, bs2_logprobs_per_prompt)):
-        assert len(logprobs_bs1) == len(logprobs_bs2), (
+    # Compare step-by-step logprobs for each prompt between BS=1 and BS=N runs.
+    for i, (logprobs_bs1, logprobs_bsN) in enumerate(
+            zip(bs1_logprobs_per_prompt, bsN_logprobs_per_prompt)):
+        assert len(logprobs_bs1) == len(logprobs_bsN), (
             f"Different number of generation steps for prompt index {i}: "
-            f"{len(logprobs_bs1)} (BS=1) vs {len(logprobs_bs2)} (BS=2)")
-        for t, (a, b) in enumerate(zip(logprobs_bs1, logprobs_bs2)):
+            f"{len(logprobs_bs1)} (BS=1) vs {len(logprobs_bsN)} (BS=N)")
+        for t, (a, b) in enumerate(zip(logprobs_bs1, logprobs_bsN)):
             assert a.shape == b.shape, (
                 f"Logits shape mismatch at prompt {i}, step {t}: "
                 f"{a.shape} vs {b.shape}")

vllm/model_executor/layers/batch_invariant.py

Lines changed: 12 additions & 1 deletion
@@ -8,8 +8,12 @@
 
 import torch
 
+import vllm.envs as envs
+from vllm.logger import init_logger
 from vllm.triton_utils import tl, triton
 
+logger = init_logger(__name__)
+
 
 def _matmul_launch_metadata(grid: Callable[..., Any], kernel: Any,
                             args: dict[str, Any]) -> dict[str, Any]:
@@ -557,5 +561,12 @@ def vllm_kernel_override_batch_invariant():
 def init_batch_invariance():
     # this will hit all the csrc overrides as well
     if vllm_kernel_override_batch_invariant():
-        os.environ["VLLM_ATTENTION_BACKEND"] = "FLEX_ATTENTION"
+        curr_attn_backend = envs.VLLM_ATTENTION_BACKEND
+        supported_backends = ["FLEX_ATTENTION", "FLASHINFER"]
+        if curr_attn_backend not in supported_backends:
+            warning = "Forcibly updating attention backend to" \
+                f" {supported_backends[0]} for batch_invariant. " \
+                f" Supported backends: {supported_backends}."
+            logger.warning_once(warning)
+            os.environ["VLLM_ATTENTION_BACKEND"] = supported_backends[0]
         enable_batch_invariant_mode()
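
In batch_invariant.py, init_batch_invariance() now keeps whichever supported attention backend is already configured and only forces a fallback, with a warning, when an unsupported one is requested. The same rule in isolation, as a standalone sketch (plain logging instead of vLLM's logger; select_attention_backend is an illustrative name, not vLLM API):

    # Standalone sketch of the backend-fallback rule added above.
    import logging
    import os

    logger = logging.getLogger(__name__)

    SUPPORTED_BATCH_INVARIANT_BACKENDS = ["FLEX_ATTENTION", "FLASHINFER"]


    def select_attention_backend(current_backend: str | None) -> str:
        """Keep a supported backend; otherwise warn and fall back."""
        if current_backend in SUPPORTED_BATCH_INVARIANT_BACKENDS:
            return current_backend
        fallback = SUPPORTED_BATCH_INVARIANT_BACKENDS[0]
        logger.warning(
            "Forcibly updating attention backend to %s for batch_invariant. "
            "Supported backends: %s.", fallback,
            SUPPORTED_BATCH_INVARIANT_BACKENDS)
        return fallback


    os.environ["VLLM_ATTENTION_BACKEND"] = select_attention_backend(
        os.environ.get("VLLM_ATTENTION_BACKEND"))

The practical effect of the diff is that setting VLLM_ATTENTION_BACKEND=FLASHINFER is no longer silently overridden to FLEX_ATTENTION when the batch-invariant kernel override is enabled.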

vllm/v1/attention/backends/flashinfer.py

Lines changed: 32 additions & 5 deletions
@@ -20,6 +20,8 @@
                                               AttentionType)
 from vllm.config import CUDAGraphMode, VllmConfig
 from vllm.logger import init_logger
+from vllm.model_executor.layers.batch_invariant import (
+    vllm_kernel_override_batch_invariant)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     QuantKey, kFp8StaticTensorSym, kNvfp4Quant)
 from vllm.platforms import current_platform
@@ -42,6 +44,7 @@
 from vllm.v1.kv_cache_interface import AttentionSpec
 
 FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
+FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT = 2048 * 1024 * 1024
 
 FP8_DTYPE = current_platform.fp8_dtype()
 FP4_DTYPE = torch.uint8
@@ -263,6 +266,15 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
         self._prefill_wrapper = None  # Wrapper for prefill/append
         self._decode_wrapper = None  # Wrapper for decode (general shape)
 
+        if vllm_kernel_override_batch_invariant():
+            self.decode_fixed_split_size = 2048
+            self.prefill_fixed_split_size = 4096
+            self.disable_split_kv = True
+        else:
+            self.decode_fixed_split_size = -1
+            self.prefill_fixed_split_size = -1
+            self.disable_split_kv = False
+
         self.compilation_config = vllm_config.compilation_config
         max_num_pages_per_req = cdiv(self.model_config.max_model_len,
                                      self.kv_cache_spec.block_size)
@@ -356,10 +368,12 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
 
     def _get_workspace_buffer(self):
         if self._workspace_buffer is None:
-            self._workspace_buffer = torch.zeros(
-                FLASHINFER_WORKSPACE_BUFFER_SIZE,
-                dtype=torch.uint8,
-                device=self.device)
+            buffer_size = FLASHINFER_WORKSPACE_BUFFER_SIZE
+            if vllm_kernel_override_batch_invariant():
+                buffer_size = FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT
+            self._workspace_buffer = torch.zeros(buffer_size,
+                                                 dtype=torch.uint8,
+                                                 device=self.device)
         return self._workspace_buffer
 
     def _get_prefill_wrapper(self):
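
Taken together, the __init__ and _get_workspace_buffer changes mean that when the batch-invariant override is active, the FlashInfer metadata builder plans with fixed split sizes (4096 for prefill, 2048 for decode), disables split-KV, and allocates the larger 2 GiB workspace instead of the default 256 MiB; presumably the point is that the split schedule, and therefore the reduction order, no longer depends on how requests happen to be batched. A small sketch of that selection factored into a helper (FlashInferPlanKnobs and plan_knobs are illustrative names, not vLLM API; in the diff the decision is made inline from vllm_kernel_override_batch_invariant()):

    # Illustrative helper mirroring the values chosen in the diff above.
    # A split size of -1 leaves the choice to FlashInfer, matching the
    # defaults added to fast_plan_decode further down.
    from dataclasses import dataclass

    MiB = 1024 * 1024


    @dataclass(frozen=True)
    class FlashInferPlanKnobs:
        prefill_fixed_split_size: int
        decode_fixed_split_size: int
        disable_split_kv: bool
        workspace_bytes: int


    def plan_knobs(batch_invariant: bool) -> FlashInferPlanKnobs:
        if batch_invariant:
            # Fixed splits, no split-KV, and the larger workspace from
            # FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT.
            return FlashInferPlanKnobs(4096, 2048, True, 2048 * MiB)
        return FlashInferPlanKnobs(-1, -1, False, 256 * MiB)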
@@ -615,6 +629,8 @@ def build(self,
                 logits_soft_cap=self.logits_soft_cap,
                 q_data_type=self.q_data_type,
                 kv_data_type=self.kv_cache_dtype,
+                fixed_split_size=self.prefill_fixed_split_size,
+                disable_split_kv=self.disable_split_kv,
             )
         else:
             attn_metadata.qo_indptr_gpu = qo_indptr_cpu.to(
@@ -668,6 +684,8 @@
                 logits_soft_cap=self.logits_soft_cap,
                 q_data_type=self.q_data_type,
                 kv_data_type=self.kv_cache_dtype,
+                fixed_split_size=self.decode_fixed_split_size,
+                disable_split_kv=self.disable_split_kv,
             )
         return attn_metadata
 
@@ -1048,6 +1066,8 @@ def fast_plan_decode(
     rope_scale: Optional[float] = None,
     rope_theta: Optional[float] = None,
     non_blocking: bool = True,
+    fixed_split_size: int = -1,
+    disable_split_kv: bool = False,
 ) -> None:
     """
     A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for
@@ -1085,6 +1105,10 @@
             rope_scale,
             rope_theta,
             non_blocking,
+            None,  # block_tables
+            None,  # seq_lens
+            fixed_split_size,
+            disable_split_kv,
         )
         self.vllm_first_call = False
         return
@@ -1130,7 +1154,7 @@
     qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")
 
     try:
-        # Make sure we pass exactly 15 arguments for tensor core version
+        # Make sure we pass exactly 18 arguments for tensor core version
         self._plan_info = self._cached_module.plan(
             self._float_workspace_buffer,
             self._int_workspace_buffer,
@@ -1147,6 +1171,9 @@
             head_dim,
             head_dim,
             False,  # causal
+            window_left,
+            fixed_split_size,
+            disable_split_kv,
         )
     except Exception as e:
         raise RuntimeError(f"Error in tensor core plan: {e}") from e
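
The remaining hunks thread the new knobs through build() and fast_plan_decode(): both plan calls now receive fixed_split_size and disable_split_kv, and the cached tensor-core plan takes 18 positional arguments instead of 15, with window_left, fixed_split_size, and disable_split_kv as the new trailing three. To exercise the whole path, the parametrized test from the first file can be run directly; a hedged example of driving it from Python with the overrides the test reads (equivalent shell exports work just as well, and the FLEX_ATTENTION/FLASHINFER cases come from the parametrize marker unless VLLM_ATTENTION_BACKEND is set explicitly):

    # Run the batch-invariance tests with the defaults documented in the diff.
    import os

    import pytest

    os.environ.setdefault("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
    os.environ.setdefault("VLLM_TEST_SEED", "12345")
    os.environ.setdefault("VLLM_TEST_TP_SIZE", "1")

    # Select just the parametrized bs1-vs-bsN logprob test; drop -k to also
    # run the needle test.
    raise SystemExit(pytest.main([
        "-v",
        "-k", "bitwise_batch_invariance_bs1_vs_bsN",
        "tests/v1/generation/test_batch_invariance.py",
    ]))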
