2020Strategy: Distribute whole images across DP ranks, not patches within images.
2121This avoids breaking cu_seqlens semantics while parallelizing ViT computation.
2222
23- Key difference from text SP:
24- - Text SP: Split sequence within attention layers, all-to-all per layer
25- - Vision DP: Split images across ranks, all_gather once at the end
23+ Key design choices:
24+ - Image-level distribution (not patch-level): avoids breaking ViT's internal
25+ cu_seqlens tracking
26+ - Contiguous assignment: rank 0 gets images [0,1,...], rank 1 gets next chunk, etc.
27+ No reordering needed after all-gather.
28+ - Gradient sync in backward: all_reduce(SUM) across SP ranks before slicing to
29+ recover the complete gradient for each image. Without this, gradients from
30+ vision tokens in other ranks' sequence shards would be lost.
31+ - No additional gradient scaling needed: the all_reduce aggregates partial
32+ sequence gradients, making each rank's ViT backward equivalent to the non-DP
33+ baseline. FSDP's dp_sp reduce-scatter then handles DP averaging as usual.
2634"""
2735
2836import torch
@@ -70,10 +78,12 @@ def assign_images_to_dp_ranks(
7078 patch_counts : list [int ],
7179 dp_size : int ,
7280) -> tuple [list [list [int ]], list [int ]]:
73- """Assign whole images to DP ranks using contiguous distribution.
81+ """Assign whole images to DP ranks using load-balanced contiguous distribution.
7482
75- Rank 0 gets images [0, 1, ...], rank 1 gets next chunk, etc.
76- This ensures no reordering is needed after all-gather.
83+ The algorithm uses greedy contiguous bin-packing:
84+ - Images are assigned in order (contiguous) to preserve ordering after gather
85+ - Split points are chosen to balance total patch load across ranks
86+ - Each rank gets at least one image when num_images >= dp_size
7787
7888 Args:
7989 patch_counts: Number of patches per image.
@@ -91,17 +101,34 @@ def assign_images_to_dp_ranks(
91101 image_assignments : list [list [int ]] = [[] for _ in range (dp_size )]
92102 rank_loads = [0 ] * dp_size
93103
94- base_size = num_images // dp_size
95- remainder = num_images % dp_size
96-
97- start = 0
104+ remaining_patches = sum (patch_counts )
105+ img_idx = 0
98106 for rank in range (dp_size ):
99- chunk_size = base_size + (1 if rank < remainder else 0 )
100- end = start + chunk_size
101- for img_idx in range (start , end ):
107+ remaining_ranks = dp_size - rank
108+ remaining_images = num_images - img_idx
109+
110+ if remaining_images <= 0 :
111+ break
112+
113+ # Dynamic target: distribute remaining patches evenly among remaining ranks
114+ target = remaining_patches / remaining_ranks
115+
116+ # Must leave at least 1 image for each of the ranks after this one
117+ max_images = remaining_images - (remaining_ranks - 1 )
118+
119+ # Greedily add images until we reach the target load or hit the max
120+ count = 0
121+ while img_idx < num_images and count < max_images :
102122 image_assignments [rank ].append (img_idx )
103123 rank_loads [rank ] += patch_counts [img_idx ]
104- start = end
124+ img_idx += 1
125+ count += 1
126+
127+ # Stop early once we've reached the target (always take at least 1)
128+ if rank_loads [rank ] >= target :
129+ break
130+
131+ remaining_patches -= rank_loads [rank ]
105132
106133 return image_assignments , rank_loads
107134
@@ -136,23 +163,32 @@ def prepare_local_vision_inputs(
136163 [],
137164 )
138165
139- patch_counts = (grid_thw [:, 0 ] * grid_thw [:, 1 ] * grid_thw [:, 2 ]).tolist ()
140- cumsum = [0 ]
141- for c in patch_counts :
142- cumsum .append (cumsum [- 1 ] + c )
166+ # local_indices are contiguous (e.g. [2, 3, 4]), so use tensor slicing
167+ first_img_idx = local_indices [0 ]
168+ last_img_idx = local_indices [- 1 ]
169+
170+ # Compute patch offsets using cumsum
171+ patch_counts = get_image_patch_counts (grid_thw )
172+ patch_counts_tensor = torch .tensor (patch_counts , device = grid_thw .device , dtype = torch .long )
173+ offsets = torch .cat (
174+ (
175+ torch .tensor ([0 ], device = grid_thw .device , dtype = torch .long ),
176+ torch .cumsum (patch_counts_tensor , dim = 0 ),
177+ )
178+ )
143179
144- local_patches = []
145- local_grids = []
146- for idx in local_indices :
147- start , end = cumsum [idx ], cumsum [idx + 1 ]
148- local_patches .append (pixel_values [start :end ])
149- local_grids .append (grid_thw [idx : idx + 1 ])
180+ start_patch = offsets [first_img_idx ].item ()
181+ end_patch = offsets [last_img_idx + 1 ].item ()
150182
151- local_pixel_values = torch . cat ( local_patches , dim = 0 )
152- local_grid_thw = torch . cat ( local_grids , dim = 0 )
183+ local_pixel_values = pixel_values [ start_patch : end_patch ]
184+ local_grid_thw = grid_thw [ first_img_idx : last_img_idx + 1 ]
153185
154- expected_patches = sum (patch_counts [idx ] for idx in local_indices )
155- assert local_pixel_values .shape [0 ] == expected_patches
186+ expected_patches = end_patch - start_patch
187+ assert local_pixel_values .shape [0 ] == expected_patches , (
188+ f"[Vision DP] Local patch count mismatch: "
189+ f"extracted={ local_pixel_values .shape [0 ]} , expected={ expected_patches } , "
190+ f"local_indices={ local_indices } "
191+ )
156192
157193 return local_pixel_values , local_grid_thw , local_indices
158194
@@ -161,28 +197,22 @@ class GatherVisionEmbeddings(Function):
161197 """All-gather vision embeddings with gradient support.
162198
163199 Contiguous assignment means simple concat without reordering.
164- Backward: scales gradients by dp_size to compensate for partial processing.
200+ Backward: all_reduce(SUM) to aggregate gradients from all sequence shards,
201+ then slice to extract this rank's image gradients.
165202 """
166203
167204 @staticmethod
168- def forward (ctx , local_embeddings , dp_group , grad_scaler = True ):
169- ctx .grad_scaler = grad_scaler
205+ def forward (ctx , local_embeddings , dp_group , all_counts : list [int ]):
170206 dp_size = dist .get_world_size (dp_group )
171207 dp_rank = dist .get_rank (dp_group )
172208 ctx .dp_size = dp_size
209+ ctx .dp_group = dp_group
210+ ctx .all_counts = all_counts
211+ ctx .dp_rank = dp_rank
173212
174213 if dp_size == 1 :
175214 return local_embeddings
176215
177- local_count = torch .tensor (
178- [local_embeddings .shape [0 ]], dtype = torch .long , device = local_embeddings .device
179- )
180- all_counts = [torch .zeros_like (local_count ) for _ in range (dp_size )]
181- dist .all_gather (all_counts , local_count , group = dp_group )
182- all_counts = [c .item () for c in all_counts ]
183- ctx .all_counts = all_counts
184- ctx .dp_rank = dp_rank
185-
186216 max_count = max (all_counts ) if all_counts else 0
187217 if max_count == 0 :
188218 return local_embeddings
@@ -211,38 +241,41 @@ def forward(ctx, local_embeddings, dp_group, grad_scaler=True):
211241 @staticmethod
212242 def backward (ctx , grad_output ):
213243 dp_size = ctx .dp_size
214- grad_scaler = ctx .grad_scaler
215244
216245 if dp_size == 1 :
217246 return grad_output , None , None
218247
219248 all_counts = ctx .all_counts
220249 dp_rank = ctx .dp_rank
250+ dp_group = ctx .dp_group
221251
222- if grad_scaler :
223- grad_output = grad_output * dp_size
252+ # Aggregate gradient contributions from all SP ranks.
253+ # Each rank only has non-zero grad for vision tokens in its own
254+ # sequence shard. Summing across ranks recovers the complete
255+ # gradient for every image before we slice by image assignment.
256+ dist .all_reduce (grad_output , op = dist .ReduceOp .SUM , group = dp_group )
224257
225258 start = sum (all_counts [:dp_rank ])
226259 end = start + all_counts [dp_rank ]
227260 local_grad = grad_output [start :end ]
228261 return local_grad , None , None
229262
230263
231- def gather_vision_embeddings (local_embeddings , dp_group = None , grad_scaler = True ):
264+ def gather_vision_embeddings (local_embeddings , dp_group , all_counts : list [ int ] ):
232265 """All-gather vision embeddings from all DP ranks.
233266
234267 Args:
235268 local_embeddings: This rank's vision embeddings.
235268 dp_group: Process group for all-gather. If None, the Ulysses group is used.
237- grad_scaler: Whether to scale gradients in backward pass .
270+ all_counts: Pre-computed embedding counts per rank (avoids an all_gather) .
238271
239272 Returns:
240273 All-gathered embeddings concatenated across ranks.
241274 """
242275 dp_group = get_ulysses_group () if dp_group is None else dp_group
243276 if dp_group is None or dist .get_world_size (dp_group ) == 1 :
244277 return local_embeddings
245- return GatherVisionEmbeddings .apply (local_embeddings , dp_group , grad_scaler )
278+ return GatherVisionEmbeddings .apply (local_embeddings , dp_group , all_counts )
246279
247280
248281def create_dp_vision_forward (original_forward ):
@@ -269,8 +302,12 @@ def dp_vision_forward(self, hidden_states, grid_thw, **kwargs):
269302 dp_group = get_ulysses_group ()
270303 dp_rank = dist .get_rank (dp_group )
271304
305+ # Move grid_thw to CPU once to avoid repeated GPU->CPU syncs in
306+ # metadata helpers (grid_thw is a tiny [num_images, 3] tensor).
307+ grid_thw_cpu = grid_thw .cpu ()
308+
272309 # Step 1: Get image assignment
273- patch_counts = get_image_patch_counts (grid_thw )
310+ patch_counts = get_image_patch_counts (grid_thw_cpu )
274311 total_patches = sum (patch_counts )
275312 assert hidden_states .shape [0 ] == total_patches
276313
@@ -280,10 +317,10 @@ def dp_vision_forward(self, hidden_states, grid_thw, **kwargs):
280317 elif hasattr (self , "spatial_merge_size" ):
281318 spatial_merge_size = self .spatial_merge_size
282319
283- embedding_counts = get_image_embedding_counts (grid_thw , spatial_merge_size )
320+ embedding_counts = get_image_embedding_counts (grid_thw_cpu , spatial_merge_size )
284321 total_embeddings = sum (embedding_counts )
285322
286- image_assignments , rank_loads = assign_images_to_dp_ranks (patch_counts , dp_size )
323+ image_assignments , _ = assign_images_to_dp_ranks (patch_counts , dp_size )
287324
288325 # Step 2: Extract local inputs
289326 local_pixels , local_grid_thw , local_indices = prepare_local_vision_inputs (
@@ -328,7 +365,9 @@ def dp_vision_forward(self, hidden_states, grid_thw, **kwargs):
328365 local_embeddings , deepstack_outputs = local_embeddings [0 ], local_embeddings [1 :]
329366
330367 # Step 4: All-gather
331- all_embeddings = gather_vision_embeddings (local_embeddings , dp_group )
368+ # Compute per-rank embedding counts locally (grid_thw is replicated on all ranks)
369+ all_counts = [sum (embedding_counts [i ] for i in image_assignments [r ]) for r in range (dp_size )]
370+ all_embeddings = gather_vision_embeddings (local_embeddings , dp_group , all_counts )
332371 assert all_embeddings .shape [0 ] == total_embeddings
333372
334373 if deepstack_outputs is not None :
@@ -339,10 +378,10 @@ def dp_vision_forward(self, hidden_states, grid_thw, **kwargs):
339378 # List of tensors (one per deepstack layer)
340379 gathered_list = []
341380 for single_emb in ds_emb :
342- gathered_list .append (gather_vision_embeddings (single_emb , dp_group ))
381+ gathered_list .append (gather_vision_embeddings (single_emb , dp_group , all_counts ))
343382 gathered_deepstack .append (gathered_list )
344383 elif isinstance (ds_emb , torch .Tensor ):
345- gathered_deepstack .append (gather_vision_embeddings (ds_emb , dp_group ))
384+ gathered_deepstack .append (gather_vision_embeddings (ds_emb , dp_group , all_counts ))
346385 else :
347386 gathered_deepstack .append (ds_emb )
348387 return (all_embeddings , * gathered_deepstack )
0 commit comments