
Commit ea3e55a

Merge remote-tracking branch 'origin/main' into attention_fusion_v1
Signed-off-by: Gregory Shtrasberg <[email protected]>
2 parents: 0e8b47f + 31f09c6

File tree: 12 files changed (+564, -36 lines)


examples/others/lmcache/disagg_prefill_lmcache_v1/disagg_example_nixl.sh

Lines changed: 8 additions & 2 deletions
@@ -21,8 +21,14 @@ check_hf_token() {
 }
 
 check_num_gpus() {
-    # can you check if the number of GPUs are >=2 via nvidia-smi?
-    num_gpus=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)
+    # can you check if the number of GPUs are >=2 via nvidia-smi/rocm-smi?
+    which rocm-smi > /dev/null 2>&1
+    if [ $? -ne 0 ]; then
+        num_gpus=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)
+    else
+        num_gpus=$(rocm-smi --showid | grep Instinct | wc -l)
+    fi
+
     if [ "$num_gpus" -lt 2 ]; then
         echo "You need at least 2 GPUs to run disaggregated prefill."
         exit 1
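For comparison, the same "at least 2 GPUs" check can be written without branching on nvidia-smi vs. rocm-smi by asking PyTorch directly, since torch.cuda.device_count() reports devices on both CUDA and ROCm builds. This is a sketch for context only, not part of the commit:

    # Sketch only (not part of the commit): a vendor-neutral version of the check.
    import sys

    import torch


    def check_num_gpus(required: int = 2) -> None:
        # device_count() covers both CUDA and ROCm builds of PyTorch.
        num_gpus = torch.cuda.device_count()
        if num_gpus < required:
            print(f"You need at least {required} GPUs to run disaggregated prefill.")
            sys.exit(1)


    if __name__ == "__main__":
        check_num_gpus()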

vllm/attention/ops/chunked_prefill_paged_decode.py

Lines changed: 3 additions & 1 deletion
@@ -63,6 +63,7 @@ def kernel_paged_attention_2d(
         stride_v_cache_3: tl.int64,  # int
         filter_by_query_len: tl.constexpr,  # bool
         query_start_len_ptr,  # [num_seqs+1]
+        USE_SINKS: tl.constexpr,  # bool
         USE_FP8: tl.constexpr,
         FP8_MIN: tl.constexpr = float8_info.min,
         FP8_MAX: tl.constexpr = float8_info.max):
@@ -101,7 +102,7 @@ def kernel_paged_attention_2d(
 
     block_table_offset = seq_idx * block_table_stride
 
-    if sink_ptr is None:
+    if not USE_SINKS:
         M = tl.full([num_queries_per_kv_padded],
                     float("-inf"),
                     dtype=tl.float32)
@@ -399,5 +400,6 @@ def chunked_prefill_paged_decode(
         stride_v_cache_3=value_cache.stride(3),
         filter_by_query_len=True,
         query_start_len_ptr=query_start_loc,
+        USE_SINKS=sinks is not None,
         USE_FP8=output_scale is not None,
     )
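The recurring change in these kernels is mechanical but worth spelling out: sink handling used to branch on `sink_ptr is None` inside the kernel and now branches on a `USE_SINKS: tl.constexpr` flag passed from the launcher. Because `tl.constexpr` arguments are compile-time constants, Triton specializes a separate kernel variant per flag value and the untaken branch is removed entirely. A minimal standalone sketch of the pattern (not vLLM code; names are illustrative):

    # Sketch, assuming a CUDA/ROCm device and a recent Triton.
    import torch
    import triton
    import triton.language as tl


    @triton.jit
    def init_running_max(out_ptr, sink_ptr, BLOCK: tl.constexpr,
                         USE_SINKS: tl.constexpr):
        offs = tl.arange(0, BLOCK)
        if not USE_SINKS:
            # Compile-time branch: this variant never touches sink_ptr.
            m = tl.full([BLOCK], float("-inf"), dtype=tl.float32)
        else:
            m = tl.load(sink_ptr + offs).to(tl.float32)
        tl.store(out_ptr + offs, m)


    def init_m(sinks=None, block=128):
        # sinks, if given, must have at least `block` elements.
        out = torch.empty(block, device="cuda", dtype=torch.float32)
        # The flag value is baked into the compiled kernel variant; when it is
        # False, the tl.load branch is dead code and sink_ptr may be None.
        init_running_max[(1, )](out, sinks, BLOCK=block,
                                USE_SINKS=sinks is not None)
        return out

The same substitution is applied in prefix_prefill.py and triton_unified_attention.py below.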

vllm/attention/ops/prefix_prefill.py

Lines changed: 3 additions & 1 deletion
@@ -83,6 +83,7 @@ def _fwd_kernel(Q,
                 num_unroll_cache: tl.constexpr,
                 num_unroll_request: tl.constexpr,
                 SKIP_DECODE: tl.constexpr,
+                USE_SINKS: tl.constexpr,
                 USE_FP8: tl.constexpr,
                 MAX_Q_LEN: tl.constexpr = 0,
                 MAX_CTX_LEN: tl.constexpr = 0,
@@ -132,7 +133,7 @@
                  other=0.0)  # [M,D]
 
     # initialize pointer to m and l
-    if sink_ptr is None:
+    if not USE_SINKS:
         m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
     else:
         m_i = tl.load(
@@ -921,5 +922,6 @@ def context_attention_fwd(q,
         num_unroll_request=1,
         num_warps=4,
         num_stages=1,
+        USE_SINKS=sinks is not None,
         **extra_kargs)
     return

vllm/attention/ops/triton_unified_attention.py

Lines changed: 15 additions & 8 deletions
@@ -78,6 +78,7 @@ def kernel_unified_attention_2d(
         USE_ALIBI_SLOPES: tl.constexpr,  # bool
         USE_QQ_BIAS: tl.constexpr,  # bool
         USE_SOFTCAP: tl.constexpr,  # bool
+        USE_SINKS: tl.constexpr,  # bool
         SLIDING_WINDOW: tl.constexpr,  # int
         stride_k_cache_0: tl.int64,  # int
         stride_k_cache_1: tl.int64,  # int
@@ -138,7 +139,7 @@
 
     block_table_offset = seq_idx * block_table_stride
 
-    if sink_ptr is None:
+    if not USE_SINKS:
         M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
     else:
         M = tl.load(
@@ -331,6 +332,7 @@ def kernel_unified_attention_3d(
         USE_ALIBI_SLOPES: tl.constexpr,  # bool
         USE_QQ_BIAS: tl.constexpr,  # bool
         USE_SOFTCAP: tl.constexpr,  # bool
+        USE_SINKS: tl.constexpr,  # bool
         SLIDING_WINDOW: tl.constexpr,  # int
         stride_k_cache_0: tl.int64,  # int
         stride_k_cache_1: tl.int64,  # int
@@ -402,14 +404,17 @@
 
     block_table_offset = seq_idx * block_table_stride
 
-    if sink_ptr is None or segm_idx != 0:
-        M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
+    if USE_SINKS:
+        if segm_idx == 0:
+            M = tl.load(
+                sink_ptr + query_offset_1,
+                mask=query_mask_1,
+                other=float("-inf"),
+            ).to(dtype=tl.float32)
+        else:
+            M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
     else:
-        M = tl.load(
-            sink_ptr + query_offset_1,
-            mask=query_mask_1,
-            other=float("-inf"),
-        ).to(dtype=tl.float32)
+        M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
 
     L = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
     acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32)
@@ -735,6 +740,7 @@ def unified_attention(
             USE_ALIBI_SLOPES=use_alibi_slopes,
             USE_QQ_BIAS=use_qq_bias,
             USE_SOFTCAP=(softcap > 0),
+            USE_SINKS=(sinks is not None),
             SLIDING_WINDOW=(1 + window_size[0]),
            stride_k_cache_0=k.stride(0),
            stride_k_cache_1=k.stride(1),
@@ -807,6 +813,7 @@
            USE_ALIBI_SLOPES=use_alibi_slopes,
            USE_QQ_BIAS=use_qq_bias,
            USE_SOFTCAP=(softcap > 0),
+            USE_SINKS=(sinks is not None),
            SLIDING_WINDOW=(1 + window_size[0]),
            stride_k_cache_0=k.stride(0),
            stride_k_cache_1=k.stride(1),
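In the 3D (split-KV) kernel the restructured branch makes the intent explicit: when sinks are enabled, only segment 0 seeds the running max `M` from the sink logit, and every other segment starts from negative infinity, so the sink term enters the softmax denominator exactly once when the per-segment statistics are merged. A plain-Python numeric sketch of that bookkeeping (illustrative, not the kernel's actual code):

    # Each segment keeps (max, sum) statistics of its softmax terms; seeding the
    # running max with the sink logit in exactly one segment adds exp(sink) to
    # the final denominator exactly once.
    import math

    logits = [0.3, -1.2, 2.0, 0.7]   # attention scores for one query
    sink = 0.5                       # learned sink logit

    # Reference: single-pass denominator including the sink term.
    ref = sum(math.exp(x) for x in logits) + math.exp(sink)


    def online_stats(xs, m=-math.inf, s=0.0):
        # Online softmax bookkeeping: s is the sum of exp(x - m) seen so far.
        for x in xs:
            new_m = max(m, x)
            s = s * math.exp(m - new_m) + math.exp(x - new_m)
            m = new_m
        return m, s


    # Segment 0 starts from the sink (m=sink, s=1.0 represents the term exp(sink));
    # segment 1 starts empty (m=-inf, s=0.0).
    m0, s0 = online_stats(logits[:2], m=sink, s=1.0)
    m1, s1 = online_stats(logits[2:])

    # Merge the per-segment statistics, as the reduction step does.
    m = max(m0, m1)
    merged = s0 * math.exp(m0 - m) + s1 * math.exp(m1 - m)

    assert abs(merged * math.exp(m) - ref) < 1e-9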

vllm/entrypoints/openai/serving_engine.py

Lines changed: 56 additions & 0 deletions
@@ -35,6 +35,7 @@
                                          apply_mistral_chat_template,
                                          parse_chat_messages_futures,
                                          resolve_chat_template_content_format)
+from vllm.entrypoints.context import ConversationContext
 from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               ChatCompletionResponse,
@@ -948,6 +949,61 @@ async def _preprocess_chat(
 
         return conversation, [request_prompt], [engine_prompt]
 
+    async def _generate_with_builtin_tools(
+        self,
+        request_id: str,
+        request_prompt: RequestPrompt,
+        engine_prompt: EngineTokensPrompt,
+        sampling_params: SamplingParams,
+        context: ConversationContext,
+        lora_request: Optional[LoRARequest] = None,
+        priority: int = 0,
+        **kwargs,
+    ):
+        orig_priority = priority
+        while True:
+            self._log_inputs(
+                request_id,
+                request_prompt,
+                params=sampling_params,
+                lora_request=lora_request,
+            )
+            generator = self.engine_client.generate(
+                engine_prompt,
+                sampling_params,
+                request_id,
+                lora_request=lora_request,
+                priority=priority,
+                **kwargs,
+            )
+            async for res in generator:
+                context.append_output(res)
+                # NOTE(woosuk): The stop condition is handled by the engine.
+                yield context
+
+            if not context.need_builtin_tool_call():
+                # The model did not ask for a tool call, so we're done.
+                break
+
+            # Call the tool and update the context with the result.
+            tool_output = await context.call_tool()
+            context.append_output(tool_output)
+
+            # TODO: uncomment this and enable tool output streaming
+            # yield context
+
+            # Create inputs for the next turn.
+            # Render the next prompt token ids.
+            prompt_token_ids = context.render_for_completion()
+            engine_prompt = EngineTokensPrompt(
+                prompt_token_ids=prompt_token_ids)
+            request_prompt = prompt_token_ids
+            # Update the sampling params.
+            sampling_params.max_tokens = (self.max_model_len -
+                                          len(prompt_token_ids))
+            # OPTIMIZATION
+            priority = orig_priority - 1
+
     def _load_prompt_embeds(
         self,
         prompt_embeds: Optional[Union[bytes, list[bytes]]],
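The new `_generate_with_builtin_tools` helper drives a generate, tool-call, re-render loop against a `ConversationContext` object. The concrete class lives in `vllm.entrypoints.context` and is not shown in this commit; the sketch below only mirrors the interface the loop relies on, with hypothetical docstrings:

    # Hypothetical sketch of the interface assumed by the loop above; the real
    # class in vllm.entrypoints.context may differ.
    from typing import Any, Protocol


    class ConversationContextLike(Protocol):
        def append_output(self, output: Any) -> None:
            """Record an engine output (or a tool result) for this conversation."""

        def need_builtin_tool_call(self) -> bool:
            """Return True if the last model output requested a built-in tool."""

        async def call_tool(self) -> Any:
            """Execute the requested tool and return its output."""

        def render_for_completion(self) -> list[int]:
            """Render the conversation so far into prompt token ids for the next turn."""

`SimpleContext`, imported in serving_responses.py below, appears to be one concrete implementation of this interface.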

vllm/entrypoints/openai/serving_responses.py

Lines changed: 17 additions & 16 deletions
@@ -16,6 +16,7 @@
 from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
                                          ChatTemplateContentFormatOption)
+from vllm.entrypoints.context import ConversationContext, SimpleContext
 from vllm.entrypoints.logger import RequestLogger
 # yapf conflicts with isort for this block
 # yapf: disable
@@ -29,7 +30,6 @@
 from vllm.entrypoints.openai.serving_engine import OpenAIServing
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.logger import init_logger
-from vllm.outputs import RequestOutput
 from vllm.reasoning import ReasoningParser, ReasoningParserManager
 from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer import AnyTokenizer
@@ -187,29 +187,27 @@ async def create_responses(
             raw_request.state.request_metadata = request_metadata
 
         # Schedule the request and get the result generator.
-        generators: list[AsyncGenerator[RequestOutput, None]] = []
+        generators: list[AsyncGenerator[ConversationContext, None]] = []
         try:
             for i, engine_prompt in enumerate(engine_prompts):
                 default_max_tokens = self.max_model_len - len(
                     engine_prompt["prompt_token_ids"])
                 sampling_params = request.to_sampling_params(
                     default_max_tokens, self.default_sampling_params)
 
-                self._log_inputs(request.request_id,
-                                 request_prompts[i],
-                                 params=sampling_params,
-                                 lora_request=lora_request)
-
                 trace_headers = (None if raw_request is None else await
                                  self._get_trace_headers(raw_request.headers))
 
-                generator = self.engine_client.generate(
-                    engine_prompt,
-                    sampling_params,
-                    request.request_id,
+                context = SimpleContext()
+                generator = self._generate_with_builtin_tools(
+                    request_id=request.request_id,
+                    request_prompt=request_prompts[i],
+                    engine_prompt=engine_prompt,
+                    sampling_params=sampling_params,
+                    context=context,
                     lora_request=lora_request,
-                    trace_headers=trace_headers,
                     priority=request.priority,
+                    trace_headers=trace_headers,
                 )
                 generators.append(generator)
         except ValueError as e:
@@ -277,25 +275,28 @@ async def responses_full_generator(
         self,
         request: ResponsesRequest,
         sampling_params: SamplingParams,
-        result_generator: AsyncIterator[RequestOutput],
+        result_generator: AsyncIterator[ConversationContext],
         model_name: str,
         tokenizer: AnyTokenizer,
         request_metadata: RequestResponseMetadata,
         created_time: Optional[int] = None,
     ) -> Union[ErrorResponse, ResponsesResponse]:
         if created_time is None:
             created_time = int(time.time())
-        final_res: Optional[RequestOutput] = None
 
+        context: Optional[ConversationContext] = None
         try:
-            async for res in result_generator:
-                final_res = res
+            async for context in result_generator:
+                pass
         except asyncio.CancelledError:
             return self.create_error_response("Client disconnected")
         except ValueError as e:
             # TODO: Use a vllm-specific Validation Error
             return self.create_error_response(str(e))
 
+        assert context is not None
+        assert isinstance(context, SimpleContext)
+        final_res = context.last_output
         assert final_res is not None
         assert len(final_res.outputs) == 1
         final_output = final_res.outputs[0]
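The full-response path now iterates `ConversationContext` objects instead of `RequestOutput`s: it simply drains the generator, keeps the last yielded context, and reads the final `RequestOutput` from `context.last_output`. The draining pattern in isolation (toy sketch, not vLLM code):

    import asyncio


    async def drain(agen):
        # Iterate the async generator to completion, keeping only the last value.
        last = None
        async for last in agen:
            pass
        return last


    async def _demo():
        async def numbers():
            for i in range(3):
                yield i

        assert await drain(numbers()) == 2


    asyncio.run(_demo())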

vllm/envs.py

Lines changed: 12 additions & 0 deletions
@@ -154,6 +154,8 @@
     VLLM_ENABLE_RESPONSES_API_STORE: bool = False
     VLLM_USE_TRTLLM_CONTEXT_ATTENTION: bool = False
     VLLM_USE_TRTLLM_DECODE_ATTENTION: bool = False
+    VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8: bool = False
+    VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False
 
 
 def get_default_cache_root():
@@ -932,6 +934,16 @@ def get_vllm_port() -> Optional[int]:
     "VLLM_USE_FLASHINFER_MOE_FP4":
     lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP4", "0"))),
 
+    # If set to 1, use the FlashInfer
+    # MXFP8 (activation) x MXFP4 (weight) MoE backend.
+    "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8":
+    lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", "0"))),
+
+    # If set to 1, use the FlashInfer
+    # BF16 (activation) x MXFP4 (weight) MoE backend.
+    "VLLM_USE_FLASHINFER_MOE_MXFP4_BF16":
+    lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16", "0"))),
+
     # Control the cache sized used by the xgrammar compiler. The default
     # of 512 MB should be enough for roughly 1000 JSON schemas.
     # It can be changed with this variable if needed for some reason.
# It can be changed with this variable if needed for some reason.

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 29 additions & 3 deletions
@@ -33,7 +33,8 @@
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.platforms.interface import CpuArchEnum
-from vllm.utils import direct_register_custom_op, has_deep_ep, has_pplx
+from vllm.utils import (direct_register_custom_op, has_deep_ep, has_pplx,
+                        round_up)
 from vllm.utils.flashinfer import has_flashinfer
 
 if current_platform.is_cuda_alike():
@@ -719,6 +720,12 @@ def __init__(
 
         self.global_num_experts = num_experts + num_redundant_experts
 
+        # we padding globally so EP buffer allocation works
+        if quant_config and quant_config.get_name() == "mxfp4" and (
+                envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
+                or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
+            hidden_size = round_up(hidden_size, 256)
+
         # For smuggling this layer into the fused moe custom op
         compilation_config = vllm_config.compilation_config
         if prefix in compilation_config.static_forward_context:
@@ -1064,6 +1071,18 @@ def weight_loader(self,
                       shard_id: str,
                       expert_id: int,
                       return_success: bool = False) -> Optional[bool]:
+
+        if self.quant_config and self.quant_config.get_name() == "mxfp4":
+            # (FIXME) for gpt-oss all experts are combined
+            if "bias" in weight_name:
+                dim1 = loaded_weight.shape[1]
+                param.data[:, :dim1].copy_(loaded_weight)
+            else:
+                dim1 = loaded_weight.shape[1]
+                dim2 = loaded_weight.shape[2]
+                param.data[:, :dim1, :dim2].copy_(loaded_weight)
+            return True if return_success else None
+
         expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
         if expert_id == -1:
             # Failed to load this param since it's not local to this rank
@@ -1476,13 +1495,20 @@ def maybe_all_reduce_tensor_model_parallel(
 
     def forward(self, hidden_states: torch.Tensor,
                 router_logits: torch.Tensor):
+        og_hidden_states = hidden_states.shape[-1]
+        if self.hidden_size != og_hidden_states:
+            hidden_states = F.pad(hidden_states,
+                                  (0, self.hidden_size - og_hidden_states),
+                                  mode='constant',
+                                  value=0.0)
         # TODO: Once the OOM issue for the TPU backend is resolved, we will
         # switch to using the moe_forward custom op.
         if current_platform.is_tpu():
             return self.forward_impl(hidden_states, router_logits)
         else:
-            return torch.ops.vllm.moe_forward(hidden_states, router_logits,
-                                              self.layer_name)
+            return torch.ops.vllm.moe_forward(
+                hidden_states, router_logits,
+                self.layer_name)[..., :og_hidden_states]
 
     def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
                              full_router_logits: torch.Tensor):
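The FusedMoE change pairs two adjustments: `hidden_size` is rounded up to a multiple of 256 at construction time when the FlashInfer MXFP4 path is enabled, so expert-parallel buffer allocation works, and `forward` pads the activations to that width and slices the output back to the original hidden size. A small shape-level sketch of the pad-then-slice step (example sizes are made up; round_up mirrors what vllm.utils.round_up is assumed to compute):

    import torch
    import torch.nn.functional as F


    def round_up(x: int, multiple: int) -> int:
        # Smallest multiple of `multiple` that is >= x.
        return ((x + multiple - 1) // multiple) * multiple


    og_hidden = 2880                            # example hidden size
    padded_hidden = round_up(og_hidden, 256)    # -> 3072

    hidden_states = torch.randn(4, og_hidden)
    padded = F.pad(hidden_states, (0, padded_hidden - og_hidden),
                   mode="constant", value=0.0)  # zero-pad the last dimension
    assert padded.shape == (4, padded_hidden)

    restored = padded[..., :og_hidden]          # slice back after the MoE kernel
    assert torch.equal(restored, hidden_states)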
