25 | 25 | MultimodalPlaceholderPlacement, |
26 | 26 | TextPrompt, |
27 | 27 | register_input_processor, |
| 28 | + support_multimodal_disaggregated, |
28 | 29 | ) |
29 | 30 | from ...inputs.multimodal import MultimodalParams |
30 | 31 | from ...logger import logger |
@@ -350,6 +351,85 @@ def __call__( |
350 | 351 | "multimodal_data": multimodal_data, |
351 | 352 | } |
352 | 353 |
| 354 | + def get_prompt_token_ids( |
| 355 | + self, inputs: TextPrompt, mm_handles: List[Dict[str, Any]] |
| 356 | + ) -> Tuple[List[int], List[int], List[int]]: |
| 357 | + """ |
| 358 | + Build input token ids with multimodal placeholders expanded to the number of MM tokens. |
| 359 | +
| 360 | + Args: |
| 361 | + inputs: Text prompt input container. Must contain a non-empty prompt string. |
| 362 | + mm_handles: List of multimodal embedding handles. Currently only a single handle is supported. |
| 363 | +
| 364 | + Returns: |
| 365 | + Tuple[List[int], List[int], List[int]]: |
| 366 | + - expanded_ids: token ids with each image token expanded to a placeholder repeated per MM token |
| 367 | + - mm_token_length: per-image MM token lengths |
| 368 | + - mm_token_offsets: start offsets (positions) for each image's MM tokens within expanded_ids |
| 369 | + """ |
| 370 | + # TODO: Move this function to the base input processor class when extending support to more models |
| 371 | + text_prompt = inputs.get("prompt") |
| 372 | + if not text_prompt: |
| 373 | + raise ValueError("Text prompt is required but not provided") |
| 374 | + |
| 375 | + if not isinstance(mm_handles, list): |
| 376 | + raise TypeError("mm_handles must be a list") |
| 377 | + |
| 378 | + if len(mm_handles) > 1: |
| 379 | + # TODO: only a single multimodal item per request is supported for now |
| 380 | + raise NotImplementedError("Only one mm_handle is supported for Qwen3 VL for now") |
| 381 | + |
| 382 | + hidden_size = mm_handles[0]["tensor_size"][1] |
| 383 | + num_deepstack_levels = len(self.config.vision_config.deepstack_visual_indexes) |
| 384 | + # Unlike previous Qwen VL models, the visual embeddings arrive concatenated with the |
| 385 | + # feature maps from the deepstack layers, hence the (1 + num_deepstack_levels) factor below. |
| 386 | + expected_size = self.config.text_config.hidden_size * (1 + num_deepstack_levels) |
| 387 | + if hidden_size != expected_size: |
| 388 | + raise RuntimeError( |
| 389 | + f"Expected multimodal embedding to have hidden size {expected_size}, got {hidden_size}." |
| 390 | + ) |
| 391 | + |
| 392 | + input_ids = self.tokenizer(text_prompt, return_tensors="pt").input_ids[0] |
| 393 | + |
| 394 | + # TODO: handle `video_token_id` as well; only image placeholder tokens are expanded for now. |
| 395 | + image_token_index = self.config.image_token_id |
| 396 | + |
| 397 | + image_mask = input_ids == image_token_index |
| 398 | + image_positions = torch.where(image_mask)[0] |
| 399 | + num_images = len(image_positions) |
| 400 | + assert num_images == len(mm_handles), "Number of images must match number of mm_handles" |
| 401 | + total_mm_tokens = sum(mm_handle["tensor_size"][0] for mm_handle in mm_handles) |
| 402 | + final_length = len(input_ids) - num_images + total_mm_tokens |
| 403 | + # Create output tensor |
| 404 | + expanded_ids = torch.empty(final_length, dtype=input_ids.dtype) |
| 405 | + placeholder_id = self.tllm_multimodal_token_id |
| 406 | + |
| 407 | + # Fill the expanded sequence |
| 408 | + write_pos = 0 |
| 409 | + image_cnt = 0 |
| 410 | + mm_token_length = [] |
| 411 | + mm_token_offsets = [] |
| 412 | + for read_pos in range(len(input_ids)): |
| 413 | + if input_ids[read_pos] == image_token_index: |
| 414 | + # Replace with placeholder id |
| 415 | + mm_token_num = mm_handles[image_cnt]["tensor_size"][0] |
| 416 | + expanded_ids[write_pos : write_pos + mm_token_num] = placeholder_id |
| 417 | + mm_token_offsets.append(write_pos) |
| 418 | + mm_token_length.append(mm_token_num) |
| 419 | + write_pos += mm_token_num |
| 420 | + image_cnt += 1 |
| 421 | + else: |
| 422 | + # Copy text token as-is |
| 423 | + expanded_ids[write_pos] = input_ids[read_pos] |
| 424 | + write_pos += 1 |
| 425 | + |
| 426 | + assert write_pos == final_length, f"Write position mismatch: {write_pos} != {final_length}" |
| 427 | + assert mm_token_length[-1] + mm_token_offsets[-1] <= final_length, ( |
| 428 | + f"mm_token_length[-1] + mm_token_offsets[-1] ({mm_token_length[-1] + mm_token_offsets[-1]}) should be less " |
| 429 | + f"than or equal to final_length ({final_length})" |
| 430 | + ) |
| 431 | + return expanded_ids.to(torch.int32).tolist(), mm_token_length, mm_token_offsets |
| 432 | + |
353 | 433 |
354 | 434 | class Qwen3VLVisionAttention(Qwen2_5_VLVisionAttention): |
355 | 435 | def __init__(self, model_config, layer_idx): |
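
As a quick illustration of the expansion performed by `get_prompt_token_ids` above, here is a minimal standalone sketch. The helper name and both token-id constants are invented for the example; the real method takes them from the tokenizer, `config.image_token_id`, and `self.tllm_multimodal_token_id`.

```python
from typing import Any, Dict, List, Tuple

import torch

IMAGE_TOKEN_ID = 151655  # stand-in for config.image_token_id (assumed value)
PLACEHOLDER_ID = -1      # stand-in for self.tllm_multimodal_token_id (assumed value)


def expand_image_tokens(
    input_ids: torch.Tensor, mm_handles: List[Dict[str, Any]]
) -> Tuple[List[int], List[int], List[int]]:
    """Replace each image token with tensor_size[0] copies of the placeholder id."""
    lengths = [handle["tensor_size"][0] for handle in mm_handles]
    expanded, offsets = [], []
    image_cnt = 0
    for tok in input_ids.tolist():
        if tok == IMAGE_TOKEN_ID:
            offsets.append(len(expanded))  # start offset of this image's MM tokens
            expanded.extend([PLACEHOLDER_ID] * lengths[image_cnt])
            image_cnt += 1
        else:
            expanded.append(tok)  # text tokens are copied as-is
    return expanded, lengths, offsets


# One image token whose MM handle carries 4 visual tokens of hidden size 4096:
ids = torch.tensor([10, 11, IMAGE_TOKEN_ID, 12])
handles = [{"tensor_size": [4, 4096]}]
print(expand_image_tokens(ids, handles))
# -> ([10, 11, -1, -1, -1, -1, 12], [4], [2])
```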
@@ -825,6 +905,7 @@ def __init__( |
825 | 905 | llm_model_config.pretrained_config.architectures = ["Qwen3MoeForCausalLM"] |
826 | 906 | else: |
827 | 907 | raise ValueError(f"Unsupported architecture: {self.original_arch}") |
| 908 | + # Build the text backbone (Qwen3ForCausalLM / Qwen3MoeForCausalLM) from the adjusted config.
828 | 909 | self.llm = AutoModelForCausalLM.from_config(llm_model_config) |
829 | 910 |
830 | 911 | if not _is_disagg(): |
@@ -953,22 +1034,16 @@ def forward( |
953 | 1034 |
954 | 1035 | # NOTE: Qwen*-VL series has mrope_config even on the text-only prompts, |
955 | 1036 | # so we need to separate the mm_multimodal_params from the text-only prompts. |
956 | | - mm_multimodal_params = [ |
957 | | - multimodal_param |
958 | | - for multimodal_param in multimodal_params |
959 | | - if multimodal_param.multimodal_data.get("image", {}).get("pixel_values") is not None |
960 | | - or multimodal_param.multimodal_data.get("video", {}).get("pixel_values_videos") |
961 | | - is not None |
962 | | - ] |
| 1037 | + mm_multimodal_params = self._get_requests_with_mm_data(multimodal_params) |
963 | 1038 | if len(mm_multimodal_params) > 0: |
964 | 1039 | if not _is_disagg(): |
965 | 1040 | mm_embeds = get_multimodal_embeddings( |
966 | 1041 | encoder_forward_fn=self.mm_encoder.forward, |
967 | 1042 | multimodal_params=mm_multimodal_params, |
968 | 1043 | ) |
969 | | - else: |
| 1044 | + elif not getattr(self, "support_mm_disagg", False): |
970 | 1045 | raise NotImplementedError( |
971 | | - "Qwen3VLModel does not support disaggregated inference yet. Please unset " |
| 1046 | + f"{type(self)} does not support disaggregated inference yet. Please unset " |
972 | 1047 | "the TLLM_MULTIMODAL_DISAGGREGATED environment variable, or set it to '0'." |
973 | 1048 | ) |
974 | 1049 | mm_embeds = find_input_mm_embeds(mm_embeds, mm_multimodal_params) |
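
The `support_mm_disagg` check above relies on the `@support_multimodal_disaggregated` class decorator applied further down in this diff. Its definition is not shown here; conceptually it only needs to mark the model class, roughly as in the hypothetical sketch below (the attribute name is taken from the `getattr` call above; the decorator body is an assumption).

```python
# Hypothetical sketch: mark the decorated model class as supporting multimodal
# disaggregated inference so forward() skips the NotImplementedError branch.
def support_multimodal_disaggregated(cls):
    cls.support_mm_disagg = True
    return cls
```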
@@ -1008,7 +1083,24 @@ def forward( |
1008 | 1083 | logger.debug(f"output shape: {output_prob.shape}") |
1009 | 1084 | return output_prob |
1010 | 1085 |
| 1086 | + def _get_requests_with_mm_data(self, multimodal_params): |
| 1087 | + mm_multimodal_params = [] |
| 1088 | + for multimodal_param in multimodal_params: |
| 1089 | + data = multimodal_param.multimodal_data |
| 1090 | + if ( |
| 1091 | + # The first two conditions check whether there is raw multimodal input the encoder still has to process. |
| 1092 | + data.get("image", {}).get("pixel_values") is not None |
| 1093 | + or data.get("video", {}).get("pixel_values_videos") is not None |
| 1094 | + # The last condition covers requests whose embeddings are already populated, e.g. |
| 1095 | + # in the prefill worker under EPD disaggregated serving. |
| 1096 | + or data.get("multimodal_embedding") |
| 1097 | + ): |
| 1098 | + mm_multimodal_params.append(multimodal_param) |
| 1099 | + |
| 1100 | + return mm_multimodal_params |
| 1101 | + |
1011 | 1102 |
| 1103 | +@support_multimodal_disaggregated |
1012 | 1104 | @register_vision_encoder(Qwen3VisionModelBase, vlm_base_model=Qwen3VisionModel) |
1013 | 1105 | @register_auto_model("Qwen3VLForConditionalGeneration") |
1014 | 1106 | @register_input_processor( |