
Commit ab7434a

[feat] Enable TP and batching for PixtralVisionModel / Mistral3VLM (NVIDIA#6152)
Signed-off-by: William Zhang <[email protected]>
1 parent b7c8a67 commit ab7434a

File tree

4 files changed: +195 -50 lines changed

tensorrt_llm/_torch/models/modeling_clip.py

Lines changed: 1 addition & 1 deletion
@@ -202,7 +202,7 @@ def prepare_attn_metadata(self, batch_size):
             request_ids=request_ids,
             prompt_lens=prompt_lens,
         )
-        attn_metadata.max_seq_len = seq_len * batch_size
+        attn_metadata.max_seq_len = seq_len
         attn_metadata.prepare()
         return attn_metadata
 
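For context on the one-line fix above: `max_seq_len` describes the longest single sequence in the batch, while `seq_len * batch_size` is the total token count across the batch (which `max_num_tokens` already covers). A purely illustrative calculation, with made-up numbers:

# Illustrative only; numbers are invented and not taken from the CLIP config.
seq_len, batch_size = 576, 4
seq_lens = [seq_len] * batch_size

max_seq_len = max(seq_lens)     # 576 -- the per-request length the fix uses
max_num_tokens = sum(seq_lens)  # 2304 -- what `seq_len * batch_size` actually measures
assert max_seq_len == seq_len
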

tensorrt_llm/_torch/models/modeling_mistral.py

Lines changed: 44 additions & 13 deletions
@@ -3,6 +3,7 @@
 from typing import Any, Dict, List, Optional, Tuple
 
 import torch
+import torchvision
 from torch import nn
 from transformers import (AutoProcessor, AutoTokenizer, Mistral3Config,
                           MistralConfig, PretrainedConfig, PreTrainedModel)
@@ -347,7 +348,6 @@ def forward(
         attn_metadata: AttentionMetadata,
         input_ids: Optional[torch.LongTensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
         return_context_logits: bool = False,
         **kwargs,
     ) -> torch.Tensor:
@@ -363,20 +363,26 @@ def forward(
             raise RuntimeError(
                 f"Number of multimodal tensors ({multimodal_params_len}) should be equal to number of "
                 f"context requests ({num_context_requests}) in the batch.")
-        # NOTES:
-        # 1. the pixel values in `multimodal_data["image"]` might vary in (height, width) between
-        #    images, making them unsafe to batch in general. The input processor also cannot produce
-        #    them in a batch, since it is always called with a single input - otherwise, we would
-        #    have been able to naturally leverage the padding / resizing capabilities of the underlying
-        #    `PixtralProcessor`.
-        # 2. After each `pixel_values` tensor has gone through the vision tower's `patch_conv` layer,
-        #    they are divided into patches that are then concatenated in order to treat them as a
-        #    single "sequence" in the vision tower's attention layers, so some form of batching still
-        #    happens in the vision tower.
-        image_features = [
-            self._get_image_features(**x.multimodal_data["image"])
+        pixel_values = [
+            x.multimodal_data["image"]["pixel_values"]
+            for x in multimodal_params
+        ]
+        image_sizes = [
+            x.multimodal_data["image"]["image_sizes"]
             for x in multimodal_params
         ]
+        if not (len(pixel_values) == len(image_sizes) ==
+                multimodal_params_len):
+            raise ValueError(
+                f"Expected as many `pixel_values` ({len(pixel_values)}) and "
+                f"`image_sizes` ({len(image_sizes)}) as number of multimodal parameters "
+                f"({multimodal_params_len}).")
+        batched_pixel_values, batched_image_sizes = self._batch_pixel_values(
+            pixel_values=pixel_values, image_sizes=image_sizes)
+        image_features = [
+            self._get_image_features(pixel_values=batched_pixel_values,
+                                     image_sizes=batched_image_sizes)
+        ]
 
         input_ids, inputs_embeds = fuse_input_embeds(
             embedding_layer=self.llm.model.embed_tokens,
@@ -429,6 +435,31 @@ def _get_image_features(
                                                     image_sizes)
         return image_features
 
+    # Original HF implementation:
+    # https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/pixtral/
+    # image_processing_pixtral.py#L276
+    # We switch to using torchvision's padding functionality since it supports torch tensors
+    # (the transformers one expected numpy arrays).
+    @staticmethod
+    @torch.inference_mode()
+    def _batch_pixel_values(
+        pixel_values: List[torch.Tensor],
+        image_sizes: List[torch.Tensor],
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        batched_image_sizes = torch.cat(image_sizes)
+        max_shape = batched_image_sizes.max(dim=0).values
+        pixel_values = [
+            torchvision.transforms.v2.functional.pad(
+                image,
+                # Per torchvision docs, this should be in LTRB order if it's a sequence of 4 numbers.
+                padding=[0, 0, max_shape[1] - size[1], max_shape[0] - size[0]],
+                # Values extracted from HF implementation.
+                fill=0.0,
+                padding_mode="constant",
+            ) for image, size in zip(pixel_values, batched_image_sizes)
+        ]
+        return torch.cat(pixel_values), batched_image_sizes
+
 
 # Original implementation:
 # https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/mistral3/modeling_mistral3.py#L66
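The idea behind `_batch_pixel_values` is standard pad-to-max batching: images of different (height, width) are padded on the right and bottom to the largest size in the batch so they can be concatenated into one tensor, while the original `image_sizes` are kept so the vision tower can crop away the padding after `patch_conv`. A minimal, self-contained sketch of that pattern (the helper name `pad_and_stack` and the shapes are illustrative, not taken from the repo):

from typing import List, Tuple

import torch
import torchvision.transforms.v2.functional as F


def pad_and_stack(images: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
    """Pad (C, H, W) tensors to a shared (H_max, W_max) and stack them into one batch."""
    sizes = torch.tensor([img.shape[-2:] for img in images])  # rows of (H, W)
    max_h, max_w = sizes.max(dim=0).values.tolist()
    padded = [
        # torchvision's v2 functional pad takes padding as [left, top, right, bottom].
        F.pad(img, padding=[0, 0, max_w - img.shape[-1], max_h - img.shape[-2]], fill=0.0)
        for img in images
    ]
    return torch.stack(padded), sizes


# Two differently sized images become a single (2, 3, 64, 80) batch.
batch, sizes = pad_and_stack([torch.randn(3, 64, 80), torch.randn(3, 48, 56)])
assert batch.shape == (2, 3, 64, 80)
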

tensorrt_llm/_torch/models/modeling_pixtral.py

Lines changed: 19 additions & 19 deletions
@@ -106,11 +106,18 @@ def forward(
 class PixtralTransformer(torch.nn.Module):
     def __init__(self, config: model_config_lib.ModelConfig[transformers.PixtralVisionConfig]):
         super().__init__()
+        tp_size = config.mapping.tp_size
+        num_heads = config.pretrained_config.num_attention_heads
+        if (num_heads % tp_size) > 0:
+            raise ValueError(f"{tp_size=} must divide {num_heads=}.")
+        num_heads //= tp_size
+
+        self._head_dim = config.pretrained_config.head_dim
+        self._num_heads = num_heads
+
         self.layers = torch.nn.ModuleList()
         for i in range(config.pretrained_config.num_hidden_layers):
             self.layers.append(PixtralAttentionLayer(config=config, layer_idx=i))
-        self._head_dim = config.pretrained_config.head_dim
-        self._num_heads = config.pretrained_config.num_attention_heads
 
     def forward(
         self,
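The new `__init__` logic gives each tensor-parallel rank an equal slice of the attention heads (per-rank head count = total heads / tp_size) and rejects configurations where the heads do not divide evenly. A tiny illustrative sketch of that arithmetic (the function name and numbers are hypothetical, not from the repo):

def heads_per_rank(num_attention_heads: int, tp_size: int) -> int:
    # Mirrors the divisibility check above: every rank must get the same number of heads.
    if num_attention_heads % tp_size:
        raise ValueError(f"{tp_size=} must divide {num_attention_heads=}.")
    return num_attention_heads // tp_size


assert heads_per_rank(16, 2) == 8  # e.g. 16 heads split across 2 tensor-parallel ranks
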
@@ -165,12 +172,6 @@ def __init__(
         self, model_config: model_config_lib.ModelConfig[transformers.PixtralVisionConfig]
     ):
         super().__init__()
-        tp_size = model_config.mapping.tp_size
-        # TODO: implement support for `tp_size > 1`.
-        if tp_size > 1:
-            raise NotImplementedError(
-                f"Mistral3VLM does not support `mapping.tp_size > 1` yet (got {tp_size})."
-            )
         # Both the below are needed in order to use `_load_weights_impl`.
         self.model_config = model_config
         self.config: transformers.PixtralVisionConfig = model_config.pretrained_config
@@ -204,12 +205,14 @@ def forward(
     ):
         with torch.autocast(device_type="cuda", dtype=self.config.torch_dtype):
             patch_embeds = self.patch_conv(pixel_values)
+
         patch_embeds_list = [
             embed[..., : (size[0] // self._patch_size), : (size[1] // self._patch_size)]
             for embed, size in zip(patch_embeds, image_sizes)
         ]
 
-        patch_embeds = torch.cat([p.flatten(1).T for p in patch_embeds_list], dim=0)
+        flattened_embeds = [p.flatten(1).T for p in patch_embeds_list]
+        patch_embeds = torch.cat(flattened_embeds, dim=0)
         patch_embeds = self.ln_pre(patch_embeds)
 
         position_ids = transformers.models.pixtral.modeling_pixtral.position_ids_in_meshgrid(
@@ -218,10 +221,8 @@ def forward(
         position_embeddings = self._patch_positional_embedding(patch_embeds, position_ids)
 
         attn_metadata = self._prepare_attn_metadata(
-            # The `torch.cat` that creates the `patch_embeds` flattens the conv features from multiple
-            # images into a single sequence - hence why we hardcode the batch size to 1 here.
-            batch_size=1,
-            seq_len=position_ids.size(0),
+            batch_size=pixel_values.size(0),
+            seq_lengths=[x.size(0) for x in flattened_embeds],
         )
         out = self.transformer(
             patch_embeds,
@@ -235,19 +236,18 @@ def forward(
     def load_weights(self, weights):
         modeling_utils._load_weights_impl(self, weights)
 
-    def _prepare_attn_metadata(self, batch_size: int, seq_len: int):
+    def _prepare_attn_metadata(self, batch_size: int, seq_lengths: List[int]):
         request_ids = list(range(1, batch_size + 1))
-        prompt_lens = [seq_len] * batch_size
         attn_metadata = self._metadata_cls(
-            seq_lens=torch.tensor([seq_len] * batch_size, dtype=torch.int),
+            seq_lens=torch.tensor(seq_lengths, dtype=torch.int),
             num_contexts=batch_size,
             max_num_requests=batch_size,
-            max_num_tokens=seq_len * batch_size,
+            max_num_tokens=sum(seq_lengths),
             kv_cache_manager=None,
             request_ids=request_ids,
-            prompt_lens=prompt_lens,
+            prompt_lens=seq_lengths,
         )
-        attn_metadata.max_seq_len = seq_len * batch_size
+        attn_metadata.max_seq_len = max(seq_lengths)
         attn_metadata.prepare()
         return attn_metadata
 
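With batching enabled, each image contributes its own sequence of patch tokens, so the metadata now carries one length per image instead of a single shared `seq_len`. Roughly, an image of size (H, W) yields (H // patch_size) * (W // patch_size) tokens after `patch_conv`; a hedged sketch of that bookkeeping (assuming a patch size of 16, as in typical Pixtral configs, and a hypothetical helper name):

from typing import List, Tuple


def per_image_seq_lengths(image_sizes: List[Tuple[int, int]], patch_size: int = 16) -> List[int]:
    # One entry per image: number of patch tokens fed to the vision transformer.
    return [(h // patch_size) * (w // patch_size) for h, w in image_sizes]


lengths = per_image_seq_lengths([(123, 456), (116, 445)])
assert lengths == [196, 189]  # 7*28 and 7*27 patches
assert max(lengths) == 196    # the kind of value `attn_metadata.max_seq_len` now takes
assert sum(lengths) == 385    # the kind of value `max_num_tokens` now takes
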

tests/unittest/_torch/modeling/test_modeling_pixtral.py

Lines changed: 131 additions & 17 deletions
@@ -1,12 +1,32 @@
+import gc
+import os
+import pathlib
+import pickle
+import sys
+
+import cloudpickle
+import mpi4py
 import pytest
 import torch
 import transformers
 from transformers.models.pixtral import modeling_pixtral as hf_modeling_pixtral
 
+import tensorrt_llm
 from tensorrt_llm import mapping as mapping_lib
 from tensorrt_llm._torch import model_config as model_config_lib
 from tensorrt_llm._torch.models import modeling_pixtral
 
+sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
+cloudpickle.register_pickle_by_value(sys.modules[__name__])
+mpi4py.MPI.pickle.__init__(
+    cloudpickle.dumps,
+    cloudpickle.loads,
+    pickle.HIGHEST_PROTOCOL,
+)
+
+# needed since we reuse the mpi executor pool, first test running will leak a thread
+pytestmark = pytest.mark.threadleak(enabled=False)
+
 
 @pytest.fixture
 def pixtral_vision_config():
@@ -49,21 +69,6 @@ def init_hf_model(cls, config, dtype, device):
     return model
 
 
-@pytest.mark.parametrize(
-    "mapping",
-    [
-        mapping_lib.Mapping(world_size=2, tp_size=2),
-        mapping_lib.Mapping(world_size=3, tp_size=3),
-        mapping_lib.Mapping(world_size=4, tp_size=2, pp_size=2),
-        mapping_lib.Mapping(world_size=8, tp_size=2, pp_size=2, cp_size=2),
-    ],
-)
-def test_pixtral_vision_model_rejects_tp_size_greater_than_one(pixtral_vision_config, mapping):
-    pixtral_vision_config.mapping = mapping
-    with pytest.raises(NotImplementedError, match="tp_size > 1"):
-        modeling_pixtral.PixtralVisionModel(model_config=pixtral_vision_config)
-
-
 @torch.no_grad()
 @pytest.mark.usefixtures("set_seed")
 def test_pixtral_vision_model_vs_hf(pixtral_vision_config):
@@ -83,10 +88,10 @@ def test_pixtral_vision_model_vs_hf(pixtral_vision_config):
     # Make sure both models have the same weights.
     pixtral_model.load_weights(hf_pixtral_model.state_dict())
 
-    batch_size = 1
+    batch_size = 2
     height, width, channels = 123, 456, 3
     pixel_values = torch.randn(batch_size, channels, height, width, device=device, dtype=dtype)
-    image_sizes = torch.tensor([[height, width]])
+    image_sizes = torch.tensor([[height, width], [height - 7, width - 11]])
     out = pixtral_model(
         pixel_values=pixel_values,
         image_sizes=image_sizes,
@@ -102,3 +107,112 @@ def test_pixtral_vision_model_vs_hf(pixtral_vision_config):
     )
 
     torch.testing.assert_close(out, hf_out, atol=0.2, rtol=0.2)
+
+
+@pytest.mark.parametrize("mpi_pool_executor", [2], indirect=True)
+@torch.no_grad()
+def test_tensor_parallelism(pixtral_vision_config, mpi_pool_executor, tmp_path):
+    mapping = mapping_lib.Mapping(world_size=2, tp_size=2)
+    if (num_available_devices := torch.cuda.device_count()) < mapping.world_size:
+        pytest.skip(f"{num_available_devices=} is less than the requested {mapping.world_size}.")
+
+    dtype = torch.bfloat16
+    device = torch.device("cuda")
+    pretrained_config = pixtral_vision_config.pretrained_config
+
+    hf_pixtral_model = init_hf_model(
+        cls=hf_modeling_pixtral.PixtralVisionModel,
+        config=pretrained_config,
+        dtype=dtype,
+        device=device,
+    )
+    # Save HF weights to disk so they can be used by worker processes.
+    state_dict = hf_pixtral_model.state_dict()
+    hf_weights_path = tmp_path / "hf_weights.pt"
+    torch.save(state_dict, hf_weights_path)
+
+    pixtral_model = (
+        modeling_pixtral.PixtralVisionModel(model_config=pixtral_vision_config).eval().to("cuda")
+    )
+    pixtral_model.load_weights(state_dict)
+    # Save the number of params to check that the model gets shared in the workers.
+    num_params = sum(p.numel() for p in pixtral_model.parameters())
+
+    batch_size = 2
+    height, width, channels = 123, 456, 3
+    pixel_values = torch.randn(batch_size, channels, height, width, device=device, dtype=dtype)
+    image_sizes = torch.tensor([[height, width], [height - 7, width - 11]])
+
+    ref_out = pixtral_model(pixel_values=pixel_values, image_sizes=image_sizes)
+
+    # Move to CPU before sending across process barrier.
+    ref_out = ref_out.to("cpu")
+    pixel_values = pixel_values.to("cpu")
+    image_sizes = image_sizes.to("cpu")
+
+    # Free up GPU memory on rank 0.
+    del state_dict
+    del hf_pixtral_model
+    del pixtral_model
+    gc.collect()
+    torch.cuda.empty_cache()
+
+    world_size = mapping.world_size
+    pixtral_vision_config.mapping = mapping
+    results = mpi_pool_executor.starmap(
+        _run_pixtral_and_compare_against_ref,
+        [
+            (
+                pixtral_vision_config,
+                hf_weights_path,
+                pixel_values,
+                image_sizes,
+                ref_out,
+                num_params,
+            )
+            for _ in range(world_size)
+        ],
+    )
+
+    for r in results:
+        assert r
+
+
+def _run_pixtral_and_compare_against_ref(
+    pixtral_vision_config: model_config_lib.ModelConfig[transformers.PixtralVisionConfig],
+    hf_weights_path: pathlib.Path,
+    pixel_values: torch.Tensor,
+    image_sizes: torch.Tensor,
+    expected_output: torch.Tensor,
+    total_num_params: int,
+) -> bool:
+    rank = tensorrt_llm.mpi_rank()
+    # Smoke check.
+    world_size = tensorrt_llm.mpi_world_size()
+    assert world_size > 1
+
+    torch.cuda.set_device(rank)
+
+    pixel_values = pixel_values.to("cuda")
+    image_sizes = image_sizes.to("cuda")
+    expected_output = expected_output.to("cuda")
+
+    pixtral_vision_config.mapping.rank = rank
+    pixtral_model = (
+        modeling_pixtral.PixtralVisionModel(model_config=pixtral_vision_config).eval().to("cuda")
+    )
+    state_dict = torch.load(hf_weights_path, map_location="cuda")
+    pixtral_model.load_weights(state_dict)
+
+    # Smoke check to see that we are indeed sharding the model.
+    rank_num_params = sum(p.numel() for p in pixtral_model.parameters())
+    params_fraction = rank_num_params / total_num_params
+    assert params_fraction < 1.0
+    assert params_fraction == pytest.approx(1.0 / world_size, rel=1e-2)
+
+    out = pixtral_model(
+        pixel_values=pixel_values,
+        image_sizes=image_sizes,
+    )
+    torch.testing.assert_close(out, expected_output, atol=0.2, rtol=0.2)
+    return True
