+import gc
+import os
+import pathlib
+import pickle
+import sys
+
+import cloudpickle
+import mpi4py
import pytest
import torch
import transformers
from transformers.models.pixtral import modeling_pixtral as hf_modeling_pixtral

+import tensorrt_llm
from tensorrt_llm import mapping as mapping_lib
from tensorrt_llm._torch import model_config as model_config_lib
from tensorrt_llm._torch.models import modeling_pixtral

+sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
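+# Pickle objects from this module by value (cloudpickle) and route mpi4py's
+# pickler through cloudpickle, so the MPI worker processes can receive
+# functions defined in this test module.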
+cloudpickle.register_pickle_by_value(sys.modules[__name__])
+mpi4py.MPI.pickle.__init__(
+    cloudpickle.dumps,
+    cloudpickle.loads,
+    pickle.HIGHEST_PROTOCOL,
+)
+
+# Needed since we reuse the MPI executor pool; the first test to run will leak a thread.
+pytestmark = pytest.mark.threadleak(enabled=False)
+

@pytest.fixture
def pixtral_vision_config():
@@ -49,21 +69,6 @@ def init_hf_model(cls, config, dtype, device):
    return model


-@pytest.mark.parametrize(
-    "mapping",
-    [
-        mapping_lib.Mapping(world_size=2, tp_size=2),
-        mapping_lib.Mapping(world_size=3, tp_size=3),
-        mapping_lib.Mapping(world_size=4, tp_size=2, pp_size=2),
-        mapping_lib.Mapping(world_size=8, tp_size=2, pp_size=2, cp_size=2),
-    ],
-)
-def test_pixtral_vision_model_rejects_tp_size_greater_than_one(pixtral_vision_config, mapping):
-    pixtral_vision_config.mapping = mapping
-    with pytest.raises(NotImplementedError, match="tp_size > 1"):
-        modeling_pixtral.PixtralVisionModel(model_config=pixtral_vision_config)
-
-
@torch.no_grad()
@pytest.mark.usefixtures("set_seed")
def test_pixtral_vision_model_vs_hf(pixtral_vision_config):
@@ -83,10 +88,10 @@ def test_pixtral_vision_model_vs_hf(pixtral_vision_config):
    # Make sure both models have the same weights.
    pixtral_model.load_weights(hf_pixtral_model.state_dict())

-    batch_size = 1
+    batch_size = 2
    height, width, channels = 123, 456, 3
    pixel_values = torch.randn(batch_size, channels, height, width, device=device, dtype=dtype)
-    image_sizes = torch.tensor([[height, width]])
+    image_sizes = torch.tensor([[height, width], [height - 7, width - 11]])
    out = pixtral_model(
        pixel_values=pixel_values,
        image_sizes=image_sizes,
@@ -102,3 +107,112 @@ def test_pixtral_vision_model_vs_hf(pixtral_vision_config):
    )

    torch.testing.assert_close(out, hf_out, atol=0.2, rtol=0.2)
+
+
+@pytest.mark.parametrize("mpi_pool_executor", [2], indirect=True)
+@torch.no_grad()
+def test_tensor_parallelism(pixtral_vision_config, mpi_pool_executor, tmp_path):
+    mapping = mapping_lib.Mapping(world_size=2, tp_size=2)
+    if (num_available_devices := torch.cuda.device_count()) < mapping.world_size:
+        pytest.skip(f"{num_available_devices=} is less than the requested {mapping.world_size}.")
+
+    dtype = torch.bfloat16
+    device = torch.device("cuda")
+    pretrained_config = pixtral_vision_config.pretrained_config
+
+    hf_pixtral_model = init_hf_model(
+        cls=hf_modeling_pixtral.PixtralVisionModel,
+        config=pretrained_config,
+        dtype=dtype,
+        device=device,
+    )
+    # Save HF weights to disk so they can be used by worker processes.
+    state_dict = hf_pixtral_model.state_dict()
+    hf_weights_path = tmp_path / "hf_weights.pt"
+    torch.save(state_dict, hf_weights_path)
+
+    pixtral_model = (
+        modeling_pixtral.PixtralVisionModel(model_config=pixtral_vision_config).eval().to("cuda")
+    )
+    pixtral_model.load_weights(state_dict)
+    # Save the number of params to check that the model gets sharded in the workers.
+    num_params = sum(p.numel() for p in pixtral_model.parameters())
+
+    batch_size = 2
+    height, width, channels = 123, 456, 3
+    pixel_values = torch.randn(batch_size, channels, height, width, device=device, dtype=dtype)
+    image_sizes = torch.tensor([[height, width], [height - 7, width - 11]])
+
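+    # Single-GPU reference output that the tensor-parallel ranks must reproduce.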
+    ref_out = pixtral_model(pixel_values=pixel_values, image_sizes=image_sizes)
+
+    # Move to CPU before sending across the process boundary.
+    ref_out = ref_out.to("cpu")
+    pixel_values = pixel_values.to("cpu")
+    image_sizes = image_sizes.to("cpu")
+
+    # Free up GPU memory on rank 0.
+    del state_dict
+    del hf_pixtral_model
+    del pixtral_model
+    gc.collect()
+    torch.cuda.empty_cache()
+
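+    # Fan the comparison out to one MPI worker per rank; each worker builds its own TP shard.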
+    world_size = mapping.world_size
+    pixtral_vision_config.mapping = mapping
+    results = mpi_pool_executor.starmap(
+        _run_pixtral_and_compare_against_ref,
+        [
+            (
+                pixtral_vision_config,
+                hf_weights_path,
+                pixel_values,
+                image_sizes,
+                ref_out,
+                num_params,
+            )
+            for _ in range(world_size)
+        ],
+    )
+
+    for r in results:
+        assert r
+
+
+def _run_pixtral_and_compare_against_ref(
+    pixtral_vision_config: model_config_lib.ModelConfig[transformers.PixtralVisionConfig],
+    hf_weights_path: pathlib.Path,
+    pixel_values: torch.Tensor,
+    image_sizes: torch.Tensor,
+    expected_output: torch.Tensor,
+    total_num_params: int,
+) -> bool:
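+    """Build this rank's TP shard of the model and compare its output against the reference."""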
+    rank = tensorrt_llm.mpi_rank()
+    # Smoke check.
+    world_size = tensorrt_llm.mpi_world_size()
+    assert world_size > 1
+
+    torch.cuda.set_device(rank)
+
+    pixel_values = pixel_values.to("cuda")
+    image_sizes = image_sizes.to("cuda")
+    expected_output = expected_output.to("cuda")
+
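+    # Point this worker's mapping at its own rank before constructing the sharded model.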
+    pixtral_vision_config.mapping.rank = rank
+    pixtral_model = (
+        modeling_pixtral.PixtralVisionModel(model_config=pixtral_vision_config).eval().to("cuda")
+    )
+    state_dict = torch.load(hf_weights_path, map_location="cuda")
+    pixtral_model.load_weights(state_dict)
+
+    # Smoke check to see that we are indeed sharding the model.
+    rank_num_params = sum(p.numel() for p in pixtral_model.parameters())
+    params_fraction = rank_num_params / total_num_params
+    assert params_fraction < 1.0
+    assert params_fraction == pytest.approx(1.0 / world_size, rel=1e-2)
+
+    out = pixtral_model(
+        pixel_values=pixel_values,
+        image_sizes=image_sizes,
+    )
+    torch.testing.assert_close(out, expected_output, atol=0.2, rtol=0.2)
+    return True