Commit a6a8898

[TRTLLM-9409][feat] Pass MRoPE tensors for EPD disagg (#9758)
* Why? Certain VLMs like the Qwen family need more than just the multimodal embeddings in the language model; they also need MRoPE position IDs and deltas. Prior to this commit, only the embeddings could be communicated from the encoder worker to the prefill worker.

* What? This commit extends `DisaggregatedParams` to include the MRoPE information. It also adjusts several pieces of code required to communicate that information between the E, P, and D workers.

Closes TRTLLM-9409.

Signed-off-by: William Zhang <[email protected]>
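For context, a minimal sketch of how the new fields can carry the MRoPE tensors from the encoder worker's result over to the prefill worker. `encoder_result` and `build_prefill_params` are hypothetical names used only for illustration; the handle and field names come from the diff below, and all other `DisaggregatedParams` fields are left at their defaults.

from tensorrt_llm.disaggregated_params import DisaggregatedParams

def build_prefill_params(encoder_result) -> DisaggregatedParams:
    # Each *_handle is a plain dict produced on the encoder worker by
    # SharedTensorContainer.from_tensor(...).dump_to_dict(), so it can be
    # serialized and shipped to the prefill worker next to the embeddings.
    return DisaggregatedParams(
        mrope_position_ids_handle=encoder_result.mrope_position_ids_handle,
        mrope_position_deltas_handle=encoder_result.mrope_position_deltas_handle,
    )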
1 parent 472fe49 commit a6a8898

File tree

10 files changed: +271 −97 lines

tensorrt_llm/_torch/models/modeling_llava_next.py

Lines changed: 4 additions & 1 deletion

@@ -527,6 +527,8 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig], *args,
             return
         if not DISAGG:
             self.mm_encoder = LlavaNextVisionModel(model_config)
+        else:
+            self.mm_encoder = None

         llm_model_config = copy.deepcopy(model_config)
         llm_model_config.pretrained_config = model_config.pretrained_config.text_config
@@ -545,7 +547,8 @@ def load_weights(self, weights, weight_mapper: BaseWeightMapper):
         if isinstance(weight_mapper, LlavaNextHfWeightMapper):
             weights = weight_mapper.preprocess_weights(weights)

-        self.mm_encoder.load_weights(weights)
+        if self.mm_encoder is not None:
+            self.mm_encoder.load_weights(weights)

         def filter_weights(weights: Dict):
             transformed_weights = {}

tensorrt_llm/_torch/models/modeling_qwen2vl.py

Lines changed: 104 additions & 11 deletions

@@ -32,7 +32,8 @@
                        BaseMultimodalInputProcessor, ExtraProcessedInputs,
                        MultimodalPlaceholderMetadata,
                        MultimodalPlaceholderPlacement, TextPrompt,
-                       register_input_processor)
+                       register_input_processor,
+                       support_multimodal_disaggregated)
 from ...logger import logger
 from ...sampling_params import SamplingParams
 from ..attention_backend import AttentionMetadata
@@ -865,6 +866,8 @@ def __init__(
             mm_encoder_config = copy.deepcopy(model_config)
             self.mm_encoder = Qwen2VisionModelBase(
                 mm_encoder_config, kwargs.get('vision_model_class', None))
+        else:
+            self.mm_encoder = None

     def init_mrope_embedding(self, model_config: ModelConfig[PretrainedConfig]):
         config = model_config.pretrained_config
@@ -953,24 +956,21 @@ def forward(
         """
         VLM forward logic with inflight batching support.
         """
-        num_context_requests, num_generation_requests = attn_metadata.num_contexts, attn_metadata.num_generations
+        num_context_requests = attn_metadata.num_contexts

         multimodal_params = kwargs.get("multimodal_params", [])
         mm_embeds = []
         mrope_config = {}
-        # NOTE: Qwen*-VL series has mrope_config even on the text-only prompts, so we need to separate the mm_multimodal_params from the text-only prompts.
-        mm_multimodal_params = [
-            multimodal_param for multimodal_param in multimodal_params
-            if multimodal_param.multimodal_data.get("image", {}).get(
-                "pixel_values") is not None or multimodal_param.multimodal_data.
-            get("video", {}).get("pixel_values_videos") is not None
-        ]
+        # NOTE: Qwen*-VL series has mrope_config even on the text-only prompts, so we need to separate
+        # the entries that do have multimodal data from those that correspond to text-only prompts.
+        mm_multimodal_params = self._get_requests_with_mm_data(
+            multimodal_params)
         if len(mm_multimodal_params) > 0:
             if not _is_disagg():
                 mm_embeds = get_multimodal_embeddings(
                     encoder_forward_fn=self.mm_encoder.forward,
                     multimodal_params=mm_multimodal_params)
-            else:
+            elif not getattr(self, "support_mm_disagg", False):
                 raise NotImplementedError(
                     "Qwen2VLModel does not support disaggregated inference yet. Please unset "
                     f"the TLLM_MULTIMODAL_DISAGGREGATED environment variable, or set it to '0'."
@@ -995,6 +995,21 @@ def forward(
         logger.debug(f'output shape: {output_prob.shape}')
         return output_prob

+    def _get_requests_with_mm_data(self, multimodal_params):
+        mm_multimodal_params = []
+        for multimodal_param in multimodal_params:
+            data = multimodal_param.multimodal_data
+            if (
+                    # The first 2 conditions check whether there is input on which inference should be run.
+                    data.get("image", {}).get("pixel_values") is not None or
+                    data.get("video", {}).get("pixel_values_videos") is not None
+                    # This condition corresponds to when the embeddings are already populated, as is e.g.
+                    # the case in EPD disagg in the prefill worker.
+                    or data.get("multimodal_embedding")):
+                mm_multimodal_params.append(multimodal_param)
+
+        return mm_multimodal_params
+

 @register_vision_encoder(Qwen2VisionModelBase,
                          vlm_base_model=Qwen2VisionTransformerPretrainedModel)
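To make the new helper's behavior concrete, the toy check below mirrors its three admission conditions on bare dicts. This is a simplification: the real inputs are multimodal parameter objects, and the example values are made up.

def has_mm_data(data: dict) -> bool:
    # Mirrors _get_requests_with_mm_data: keep requests with raw image or video
    # pixels, or with an embedding that was already computed (EPD prefill case).
    return (data.get("image", {}).get("pixel_values") is not None
            or data.get("video", {}).get("pixel_values_videos") is not None
            or bool(data.get("multimodal_embedding")))

assert has_mm_data({"image": {"pixel_values": [[0.1, 0.2]]}})             # raw image input
assert has_mm_data({"multimodal_embedding": {"handle": "..."}})           # embedding already populated
assert not has_mm_data({"mrope_config": {"mrope_position_deltas": [0]}})  # text-only prompt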
@@ -1032,11 +1047,89 @@ def load_weights(self, weights, weight_mapper: BaseWeightMapper):
         self.llm.load_weights(weights, weight_mapper)


+class Qwen2_5VLInputProcessorBase(Qwen2VLInputProcessorBase):
+
+    def get_prompt_token_ids(
+            self, inputs: TextPrompt,
+            mm_handles: List[Dict[str,
+                                  Any]]) -> Tuple[List[int], List[int], List[int]]:
+        """
+        Build input token ids with multimodal placeholders expanded to the number of MM tokens.
+
+        Args:
+            inputs: Text prompt input container. Must contain a non-empty prompt string.
+            mm_handles: List of multimodal embedding handles. Currently only a single handle is supported.
+
+        Returns:
+            Tuple[List[int], List[int], List[int]]:
+                - expanded_ids: token ids with each image token expanded to a placeholder repeated per MM token
+                - mm_token_length: per-image MM token lengths
+                - mm_token_offsets: start offsets (positions) for each image's MM tokens within expanded_ids
+        """
+        # TODO: Move this function to the base input processor class when extending for more models
+        text_prompt = inputs.get("prompt")
+        if not text_prompt:
+            raise ValueError("Text prompt is required but not provided")
+
+        if not isinstance(mm_handles, list):
+            raise TypeError("mm_handles must be a list")
+
+        if len(mm_handles) != 1:
+            # TODO: only support single multimodal item within a request for now
+            raise NotImplementedError(
+                "Only one mm_handle is supported for Qwen2.5 VL for now")
+        hidden_size = mm_handles[0]['tensor_size'][1]
+        assert hidden_size == self.config.text_config.hidden_size, "Multimodal embedding hidden size must match model hidden size"
+        input_ids = self.tokenizer(text_prompt,
+                                   return_tensors="pt").input_ids[0]
+
+        image_token_index = self.config.image_token_id
+
+        image_mask = input_ids == image_token_index
+        image_positions = torch.where(image_mask)[0]
+        num_images = len(image_positions)
+        assert num_images == len(
+            mm_handles), "Number of images must match number of mm_handles"
+        total_mm_tokens = sum(mm_handle["tensor_size"][0]
+                              for mm_handle in mm_handles)
+        final_length = len(input_ids) - num_images + total_mm_tokens
+        # Create output tensor
+        expanded_ids = torch.empty(final_length, dtype=input_ids.dtype)
+        placeholder_id = self.tllm_multimodal_token_id
+
+        # Fill the expanded sequence
+        write_pos = 0
+        image_cnt = 0
+        mm_token_length = []
+        mm_token_offsets = []
+        for read_pos in range(len(input_ids)):
+            if input_ids[read_pos] == image_token_index:
+                # Replace with placeholder id
+                mm_token_num = mm_handles[image_cnt]["tensor_size"][0]
+                expanded_ids[write_pos:write_pos + mm_token_num] = \
+                    placeholder_id
+                mm_token_offsets.append(write_pos)
+                mm_token_length.append(mm_token_num)
+                write_pos += mm_token_num
+                image_cnt += 1
+            else:
+                # Copy text token as-is
+                expanded_ids[write_pos] = input_ids[read_pos]
+                write_pos += 1
+
+        assert write_pos == final_length, f"Write position mismatch: {write_pos} != {final_length}"
+        assert mm_token_length[-1] + mm_token_offsets[
+            -1] <= final_length, f"mm_token_length[-1] + mm_token_offsets[-1] ({mm_token_length[-1] + mm_token_offsets[-1]}) should be less than or equal to final_length ({final_length})"
+        return expanded_ids.to(
+            torch.int32).tolist(), mm_token_length, mm_token_offsets
+
+
+@support_multimodal_disaggregated
 @register_vision_encoder(Qwen2VisionModelBase,
                          vlm_base_model=Qwen2_5_VisionModel)
 @register_auto_model("Qwen2_5_VLForConditionalGeneration")
 @register_input_processor(
-    Qwen2VLInputProcessorBase,
+    Qwen2_5VLInputProcessorBase,
     model_type="qwen2_5_vl",
     placeholder_metadata=MultimodalPlaceholderMetadata(
         placeholder_map={
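As a worked example of the expansion above, suppose the tokenizer yields four token ids containing one image token, and the single handle reports tensor_size = (4, hidden_size), i.e. four MM tokens. The values below are made up; the loop is a standalone simplification of the method's expansion step.

IMG, PH = 151655, -1            # illustrative ids: image token and TRT-LLM placeholder
input_ids = [15, 15, IMG, 27]
mm_tokens = 4                   # mm_handles[0]["tensor_size"][0]

expanded_ids, mm_token_offsets, mm_token_length = [], [], []
for tok in input_ids:
    if tok == IMG:
        mm_token_offsets.append(len(expanded_ids))
        mm_token_length.append(mm_tokens)
        expanded_ids.extend([PH] * mm_tokens)
    else:
        expanded_ids.append(tok)

assert expanded_ids == [15, 15, PH, PH, PH, PH, 27]       # final_length = 4 - 1 + 4 = 7
assert (mm_token_offsets, mm_token_length) == ([2], [4])  # placeholders start at position 2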

tensorrt_llm/_torch/pyexecutor/llm_request.py

Lines changed: 26 additions & 1 deletion

@@ -262,6 +262,8 @@ def __init__(self,
             chunk_size=self._chunk_size) if return_generation_logits else None
         self._log_probs = LogProbStorage() if return_log_probs else None
         self._mm_embeddings = None
+        self._mrope_position_ids = None
+        self._mrope_position_deltas = None
         self._additional_context_outputs = {
             name: []
             for name in additional_outputs
@@ -293,6 +295,16 @@ def append_mm_embeddings(self, mm_embeddings: torch.Tensor):
         self._mm_embeddings = SharedTensorContainer.from_tensor(
             mm_embeddings).dump_to_dict()

+    def set_mrope_position(
+        self,
+        mrope_position_ids: torch.Tensor,
+        mrope_position_deltas: torch.Tensor,
+    ):
+        self._mrope_position_ids = (SharedTensorContainer.from_tensor(
+            mrope_position_ids).dump_to_dict())
+        self._mrope_position_deltas = (SharedTensorContainer.from_tensor(
+            mrope_position_deltas).dump_to_dict())
+
     def transfer_remaining_device_logits(self):
         """Finalize any remaining generation logits transfers (for chunked mode)"""
         if self._generation_logits:
@@ -352,6 +364,18 @@ def cum_log_probs(self) -> list[float] | None:
     def mm_embedding_handle(self) -> Dict[str, Any] | None:
         return self._mm_embeddings

+    @property
+    def mrope_position_ids_handle(self) -> Dict[str, Any] | None:
+        # NOTE: when populated, the returned `dict` contains the information necessary to rebuild
+        # the `SharedTensorContainer` using the `from_dict` class method.
+        return self._mrope_position_ids
+
+    @property
+    def mrope_position_deltas_handle(self) -> Dict[str, Any] | None:
+        # NOTE: when populated, the returned `dict` contains the information necessary to rebuild
+        # the `SharedTensorContainer` using the `from_dict` class method.
+        return self._mrope_position_deltas
+
     @property
     def additional_context_outputs(self) -> Dict[str, torch.Tensor] | None:
         if self._additional_context_outputs is None:
@@ -382,7 +406,8 @@ class LlmResult:
     py_result_properties = frozenset(
         ('context_logits', 'generation_logits', 'log_probs', 'cum_log_probs',
          'mm_embedding_handle', 'additional_context_outputs',
-         'additional_generation_outputs'))
+         'additional_generation_outputs', 'mrope_position_ids_handle',
+         'mrope_position_deltas_handle'))

     def __init__(self,
                  result: Union[bytes, tensorrt_llm.bindings.executor.Result],
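A rough sketch of the round trip those handles enable, under the assumption that `SharedTensorContainer` is importable from `tensorrt_llm._torch.shared_tensor` (the import path is not shown in this diff); the final tensor-materialization step on the consumer side is likewise only indicated as a comment.

import torch
from tensorrt_llm._torch.shared_tensor import SharedTensorContainer  # assumed path

# Producer side (encoder worker): set_mrope_position() stores plain-dict handles.
mrope_position_ids = torch.zeros(3, 1, 128, dtype=torch.int64)   # illustrative shape
handle = SharedTensorContainer.from_tensor(mrope_position_ids).dump_to_dict()

# The handle is an ordinary dict, so it can be serialized and sent across
# processes as part of the disaggregated request/response payload.
assert isinstance(handle, dict)

# Consumer side (prefill worker): rebuild the container from the handle, as the
# NOTE on the new properties describes, then recover the tensor through the
# container's accessor (not shown in this diff).
container = SharedTensorContainer.from_dict(handle)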

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 31 additions & 10 deletions

@@ -2213,13 +2213,14 @@ def _prepare_tp_inputs(
                             mrope_position_deltas).expand(
                                 3, 1, 1)
                     mrope_position_ids.append(gen_mrope_position_ids)
-                    multimodal_params.to_device(
-                        "multimodal_data",
-                        "cuda",
-                        pin_memory=True,
-                        target_keywords=[
-                            "mrope_config.mrope_position_deltas"
-                        ])
+                    if mrope_position_deltas.device.type == "cpu":
+                        multimodal_params.to_device(
+                            "multimodal_data",
+                            "cuda",
+                            pin_memory=True,
+                            target_keywords=[
+                                "mrope_config.mrope_position_deltas"
+                            ])
                     multimodal_params_list.append(multimodal_params)

             request.py_batch_idx = request.py_seq_slot
@@ -2448,8 +2449,9 @@ def previous_seq_slots_device():
                 # NOTE: self.use_mrope is enough for differentiating whether to use mrope_position_ids but
                 # `_create_dummy_context_requests` from `kv_cache_creater` makes an exception that I can not add multimodal_data to the dummy_request
                 # so that we only replace position_ids with mrope_position_ids when it is not a dummy request and for models who is using mrope.
-                mrope_position_ids = torch.cat(mrope_position_ids,
-                                               dim=-1).pin_memory()
+                mrope_position_ids = torch.cat(mrope_position_ids, dim=-1)
+                if mrope_position_ids.device.type == "cpu":
+                    mrope_position_ids = mrope_position_ids.pin_memory()
                 self.mrope_position_ids_cuda[:, :, :total_num_tokens].copy_(
                     mrope_position_ids[:, :, :total_num_tokens], non_blocking=True)
                 final_position_ids = self.mrope_position_ids_cuda[:, :, :
@@ -3362,7 +3364,26 @@ def _forward_step_mm_encoder_only(
             mm_embeddings = list(
                 torch.split(mm_embeddings[0], multimodal_chunks, dim=0))

-        return {'mm_embeddings': mm_embeddings, 'logits': None}
+        # Extract mrope position data from multimodal_params if available
+        mrope_position_ids_list = []
+        mrope_position_deltas_list = []
+        for multimodal_param in multimodal_params:
+            mrope_config = multimodal_param.multimodal_data.get(
+                'mrope_config', {})
+            mrope_position_ids = mrope_config.get('mrope_position_ids')
+            mrope_position_deltas = mrope_config.get('mrope_position_deltas')
+            if mrope_position_ids is not None:
+                mrope_position_ids_list.append(mrope_position_ids)
+            if mrope_position_deltas is not None:
+                mrope_position_deltas_list.append(mrope_position_deltas)
+
+        result = {'mm_embeddings': mm_embeddings, 'logits': None}
+        if mrope_position_ids_list:
+            result['mrope_position_ids'] = mrope_position_ids_list
+        if mrope_position_deltas_list:
+            result['mrope_position_deltas'] = mrope_position_deltas_list
+
+        return result

     def _init_userbuffers(self, hidden_size):
         if self.mapping.tp_size <= 1 or self.mapping.pp_size > 1:
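The new `device.type == "cpu"` checks matter because page-locking only applies to host memory; presumably the MRoPE tensors can now already arrive as device tensors in the EPD disagg path, and unconditionally calling pin_memory() on them would be invalid. A minimal illustration of the guard, with made-up shapes:

import torch

def maybe_pin(t: torch.Tensor) -> torch.Tensor:
    # pin_memory() is only meaningful for CPU tensors; tensors already on the
    # GPU are returned untouched, mirroring the guard added above.
    return t.pin_memory() if t.device.type == "cpu" else t

if torch.cuda.is_available():
    cpu_ids = torch.zeros(3, 1, 16, dtype=torch.int64)
    assert maybe_pin(cpu_ids).is_pinned()      # host tensor gets page-locked
    gpu_ids = cpu_ids.cuda()
    assert maybe_pin(gpu_ids) is gpu_ids       # already on device: left alone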

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 19 additions & 3 deletions

@@ -21,7 +21,7 @@
 from dataclasses import dataclass
 from functools import cached_property
 from itertools import repeat
-from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, cast
+from typing import Any, Callable, Dict, Generic, List, Optional, Type, TypeVar, cast

 import numpy as np
 import torch
@@ -199,6 +199,8 @@ def is_generation_model(self) -> bool:
 @dataclass(kw_only=True)
 class MultimodalResult:
     mm_embeddings: List[torch.Tensor]
+    # Can be used to include e.g. `mrope_position_ids`, etc.
+    extra_data: Optional[Dict[str, Any]] = None

     def values(self):
         return vars(self).values()
@@ -262,7 +264,10 @@ def sample_async(
         resource_manager: Optional[ResourceManager] = None,
     ) -> SampleStateWithMMResult:
         # from model_outputs to MultimodalResult
-        data = MultimodalResult(mm_embeddings=model_outputs["mm_embeddings"])
+        data = MultimodalResult(
+            mm_embeddings=model_outputs.pop("mm_embeddings"),
+            extra_data={**model_outputs},
+        )
         return SampleStateWithMMResult(scheduled_requests=scheduled_requests, data=data)

     @override
@@ -276,7 +281,12 @@ def update_requests(
         scheduled_requests = state.scheduled_requests
         assert not scheduled_requests.generation_requests
         mm_embeddings = state.data.mm_embeddings
-        for request, mm_embedding in zip(scheduled_requests.context_requests, mm_embeddings):
+        extra_data = state.data.extra_data or {}
+        mrope_position_ids = extra_data.get("mrope_position_ids", None)
+        mrope_position_deltas = extra_data.get("mrope_position_deltas", None)
+        for i, (request, mm_embedding) in enumerate(
+            zip(scheduled_requests.context_requests, mm_embeddings)
+        ):
             request.state = LlmRequestState.GENERATION_COMPLETE
             # NOTE: This is a hack: set finish reason manually and set the beam 0
             request.set_finished_reason(FinishReason.LENGTH, 0)
@@ -287,6 +297,12 @@ def update_requests(

             request.py_result.append_mm_embeddings(mm_embedding)

+            # Store mrope data if available
+            if mrope_position_ids is not None and mrope_position_deltas is not None:
+                request.py_result.set_mrope_position(
+                    mrope_position_ids[i], mrope_position_deltas[i]
+                )
+
     @override
     def is_generation_model(self) -> bool:
         return False
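For orientation, the encoder worker's outputs consumed by this sampler now look roughly like the dict below (keys match `_forward_step_mm_encoder_only` above; shapes and values are made up). Everything except `mm_embeddings` is carried through `extra_data` and then indexed per request in `update_requests`.

import torch

model_outputs = {
    "mm_embeddings": [torch.randn(256, 3584)],                        # one entry per request
    "logits": None,
    "mrope_position_ids": [torch.zeros(3, 1, 256, dtype=torch.int64)],
    "mrope_position_deltas": [torch.zeros(1, 1, dtype=torch.int64)],
}

mm_embeddings = model_outputs.pop("mm_embeddings")
extra_data = {**model_outputs}
# update_requests() pairs request i with mm_embeddings[i] and, when both lists are
# present, extra_data["mrope_position_ids"][i] / extra_data["mrope_position_deltas"][i].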

tensorrt_llm/disaggregated_params.py

Lines changed: 2 additions & 0 deletions

@@ -40,6 +40,8 @@ class DisaggregatedParams:
     multimodal_hashes: Optional[List[List[int]]] = (
         None  # user provided mm hashes should be a list of 8 integers
     )
+    mrope_position_ids_handle: Optional[Dict[str, Any]] = None
+    mrope_position_deltas_handle: Optional[Dict[str, Any]] = None

     def get_context_phase_params(self) -> tllme.ContextPhaseParams:
         return tllme.ContextPhaseParams(
