Skip to content

Commit 1b13eaf

Browse files
aoshen524 authored and claude committed
fix(vision_dp): fix gradient routing, load balancing, and efficiency issues
Address reviewer comments (same fixes as verl PR #5230 and AReaL PR #929): 1. **Gradient routing fix (critical)**: Replace `grad_scaler * dp_size` with `all_reduce(SUM)` in GatherVisionEmbeddings.backward() to aggregate partial sequence gradients before slicing. Fixes silent gradient loss when vision tokens span multiple sequence shard boundaries. 2. **Load-balanced assignment**: Replace count-based chunking with greedy contiguous bin-packing that balances total patch load across ranks. 3. **Remove unnecessary all_gather**: Pass pre-computed `all_counts` from caller instead of doing all_gather in forward. 4. **Idempotency guard**: Extract `_patch_vision_class()` helper with `_vision_dp_patched` attribute check. Add `_unapply_vision_class()` to properly clear the flag on unapply. 5. **Remove Qwen3-VL-MoE dead code**: Remove unreachable qwen3_vl_moe blocks from apply/unapply (not yet in transformers vl_model_mappings). 6. **GPU→CPU sync optimization**: Move `grid_thw.cpu()` to dp_vision_forward entry point to avoid repeated `.tolist()` GPU→CPU syncs. 7. **Tensor slicing**: Replace Python loop + list append in prepare_local_vision_inputs with contiguous tensor slice using cumsum. 8. **Test improvements**: Rename tests, add load balancing test, add gather_none_group test, use parametrize. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent c8eba5f commit 1b13eaf

File tree

3 files changed

+252
-250
lines changed

3 files changed

+252
-250
lines changed

roll/utils/context_parallel/monkey_patch.py

Lines changed: 41 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -38,24 +38,33 @@ def apply_ulysses_patch():
3838
return patch_info
3939

4040

41+
def _patch_vision_class(cls, key, class_name):
42+
"""Patch a single VisionTransformer class with Vision DP, with idempotency guard."""
43+
from .vision_dp import create_dp_vision_forward
44+
45+
if getattr(cls, "_vision_dp_patched", False):
46+
return
47+
original = cls.forward
48+
_original_vision_forwards[key] = original
49+
cls.forward = create_dp_vision_forward(original)
50+
cls._vision_dp_patched = True
51+
logger.info(f"Monkey patch {class_name}.forward for Vision DP")
52+
53+
4154
def apply_vision_dp_patch():
4255
"""Patch VisionTransformer.forward for Vision Data Parallel.
4356
4457
Distributes whole images across Ulysses SP ranks for parallelized ViT computation.
4558
Each rank processes 1/sp_size of images, then all-gathers embeddings.
4659
4760
This reduces ViT peak memory by ~sp_size x (e.g. SP=4 -> ~4x reduction).
61+
Safe to call multiple times -- each class is only patched once.
4862
"""
49-
from .vision_dp import create_dp_vision_forward
50-
5163
# Patch Qwen2-VL VisionTransformer
5264
try:
5365
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VisionTransformerPretrainedModel
5466

55-
original = Qwen2VisionTransformerPretrainedModel.forward
56-
_original_vision_forwards["qwen2_vl"] = original
57-
Qwen2VisionTransformerPretrainedModel.forward = create_dp_vision_forward(original)
58-
logger.info("Monkey patch Qwen2VisionTransformerPretrainedModel.forward for Vision DP")
67+
_patch_vision_class(Qwen2VisionTransformerPretrainedModel, "qwen2_vl", "Qwen2VisionTransformerPretrainedModel")
5968
except ImportError as e:
6069
logger.debug(f"Qwen2-VL not available for Vision DP patch: {e}")
6170

@@ -65,71 +74,52 @@ def apply_vision_dp_patch():
6574
Qwen2_5_VisionTransformerPretrainedModel,
6675
)
6776

68-
original = Qwen2_5_VisionTransformerPretrainedModel.forward
69-
_original_vision_forwards["qwen2_5_vl"] = original
70-
Qwen2_5_VisionTransformerPretrainedModel.forward = create_dp_vision_forward(original)
71-
logger.info("Monkey patch Qwen2_5_VisionTransformerPretrainedModel.forward for Vision DP")
77+
_patch_vision_class(
78+
Qwen2_5_VisionTransformerPretrainedModel, "qwen2_5_vl", "Qwen2_5_VisionTransformerPretrainedModel"
79+
)
7280
except ImportError as e:
7381
logger.debug(f"Qwen2.5-VL not available for Vision DP patch: {e}")
7482

7583
# Patch Qwen3-VL VisionModel
7684
try:
7785
from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLVisionModel
7886

79-
original = Qwen3VLVisionModel.forward
80-
_original_vision_forwards["qwen3_vl"] = original
81-
Qwen3VLVisionModel.forward = create_dp_vision_forward(original)
82-
logger.info("Monkey patch Qwen3VLVisionModel.forward for Vision DP")
87+
_patch_vision_class(Qwen3VLVisionModel, "qwen3_vl", "Qwen3VLVisionModel")
8388
except ImportError as e:
8489
logger.debug(f"Qwen3-VL not available for Vision DP patch: {e}")
8590

86-
# Patch Qwen3-VL-MoE VisionModel
87-
try:
88-
from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeVisionModel
8991

90-
original = Qwen3VLMoeVisionModel.forward
91-
_original_vision_forwards["qwen3_vl_moe"] = original
92-
Qwen3VLMoeVisionModel.forward = create_dp_vision_forward(original)
93-
logger.info("Monkey patch Qwen3VLMoeVisionModel.forward for Vision DP")
94-
except ImportError as e:
95-
logger.debug(f"Qwen3-VL-MoE not available for Vision DP patch: {e}")
92+
def _unapply_vision_class(cls, key):
93+
"""Restore a single VisionTransformer class, clearing the idempotency flag."""
94+
if key in _original_vision_forwards:
95+
cls.forward = _original_vision_forwards.pop(key)
96+
cls._vision_dp_patched = False
9697

9798

9899
def unapply_vision_dp_patch():
99100
"""Restore original VisionTransformer.forward methods."""
100-
if "qwen2_vl" in _original_vision_forwards:
101-
try:
102-
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VisionTransformerPretrainedModel
103-
104-
Qwen2VisionTransformerPretrainedModel.forward = _original_vision_forwards.pop("qwen2_vl")
105-
except ImportError:
106-
pass
107-
108-
if "qwen2_5_vl" in _original_vision_forwards:
109-
try:
110-
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
111-
Qwen2_5_VisionTransformerPretrainedModel,
112-
)
101+
try:
102+
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VisionTransformerPretrainedModel
113103

114-
Qwen2_5_VisionTransformerPretrainedModel.forward = _original_vision_forwards.pop("qwen2_5_vl")
115-
except ImportError:
116-
pass
104+
_unapply_vision_class(Qwen2VisionTransformerPretrainedModel, "qwen2_vl")
105+
except ImportError:
106+
pass
117107

118-
if "qwen3_vl" in _original_vision_forwards:
119-
try:
120-
from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLVisionModel
108+
try:
109+
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
110+
Qwen2_5_VisionTransformerPretrainedModel,
111+
)
121112

122-
Qwen3VLVisionModel.forward = _original_vision_forwards.pop("qwen3_vl")
123-
except ImportError:
124-
pass
113+
_unapply_vision_class(Qwen2_5_VisionTransformerPretrainedModel, "qwen2_5_vl")
114+
except ImportError:
115+
pass
125116

126-
if "qwen3_vl_moe" in _original_vision_forwards:
127-
try:
128-
from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeVisionModel
117+
try:
118+
from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLVisionModel
129119

130-
Qwen3VLMoeVisionModel.forward = _original_vision_forwards.pop("qwen3_vl_moe")
131-
except ImportError:
132-
pass
120+
_unapply_vision_class(Qwen3VLVisionModel, "qwen3_vl")
121+
except ImportError:
122+
pass
133123

134124

135125
def unapply_ulysses_patch():

roll/utils/context_parallel/vision_dp.py

Lines changed: 91 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,17 @@
2020
Strategy: Distribute whole images across DP ranks, not patches within images.
2121
This avoids breaking cu_seqlens semantics while parallelizing ViT computation.
2222
23-
Key difference from text SP:
24-
- Text SP: Split sequence within attention layers, all-to-all per layer
25-
- Vision DP: Split images across ranks, all_gather once at the end
23+
Key design choices:
24+
- Image-level distribution (not patch-level): avoids breaking ViT's internal
25+
cu_seqlens tracking
26+
- Contiguous assignment: rank 0 gets images [0,1,...], rank 1 gets next chunk, etc.
27+
No reordering needed after all-gather.
28+
- Gradient sync in backward: all_reduce(SUM) across SP ranks before slicing to
29+
recover the complete gradient for each image. Without this, gradients from
30+
vision tokens in other ranks' sequence shards would be lost.
31+
- No additional gradient scaling needed: the all_reduce aggregates partial
32+
sequence gradients, making each rank's ViT backward equivalent to the non-DP
33+
baseline. FSDP's dp_sp reduce-scatter then handles DP averaging as usual.
2634
"""
2735

2836
import torch
@@ -70,10 +78,12 @@ def assign_images_to_dp_ranks(
7078
patch_counts: list[int],
7179
dp_size: int,
7280
) -> tuple[list[list[int]], list[int]]:
73-
"""Assign whole images to DP ranks using contiguous distribution.
81+
"""Assign whole images to DP ranks using load-balanced contiguous distribution.
7482
75-
Rank 0 gets images [0, 1, ...], rank 1 gets next chunk, etc.
76-
This ensures no reordering is needed after all-gather.
83+
The algorithm uses greedy contiguous bin-packing:
84+
- Images are assigned in order (contiguous) to preserve ordering after gather
85+
- Split points are chosen to balance total patch load across ranks
86+
- Each rank gets at least one image when num_images >= dp_size
7787
7888
Args:
7989
patch_counts: Number of patches per image.
@@ -91,17 +101,34 @@ def assign_images_to_dp_ranks(
91101
image_assignments: list[list[int]] = [[] for _ in range(dp_size)]
92102
rank_loads = [0] * dp_size
93103

94-
base_size = num_images // dp_size
95-
remainder = num_images % dp_size
96-
97-
start = 0
104+
remaining_patches = sum(patch_counts)
105+
img_idx = 0
98106
for rank in range(dp_size):
99-
chunk_size = base_size + (1 if rank < remainder else 0)
100-
end = start + chunk_size
101-
for img_idx in range(start, end):
107+
remaining_ranks = dp_size - rank
108+
remaining_images = num_images - img_idx
109+
110+
if remaining_images <= 0:
111+
break
112+
113+
# Dynamic target: distribute remaining patches evenly among remaining ranks
114+
target = remaining_patches / remaining_ranks
115+
116+
# Must leave at least 1 image for each remaining rank
117+
max_images = remaining_images - (remaining_ranks - 1)
118+
119+
# Greedily add images until we reach the target load or hit the max
120+
count = 0
121+
while img_idx < num_images and count < max_images:
102122
image_assignments[rank].append(img_idx)
103123
rank_loads[rank] += patch_counts[img_idx]
104-
start = end
124+
img_idx += 1
125+
count += 1
126+
127+
# Stop early once we've reached the target (always take at least 1)
128+
if rank_loads[rank] >= target:
129+
break
130+
131+
remaining_patches -= rank_loads[rank]
105132

106133
return image_assignments, rank_loads
107134

@@ -136,23 +163,32 @@ def prepare_local_vision_inputs(
136163
[],
137164
)
138165

139-
patch_counts = (grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]).tolist()
140-
cumsum = [0]
141-
for c in patch_counts:
142-
cumsum.append(cumsum[-1] + c)
166+
# local_indices are contiguous (e.g. [2, 3, 4]), so use tensor slicing
167+
first_img_idx = local_indices[0]
168+
last_img_idx = local_indices[-1]
169+
170+
# Compute patch offsets using cumsum
171+
patch_counts = get_image_patch_counts(grid_thw)
172+
patch_counts_tensor = torch.tensor(patch_counts, device=grid_thw.device, dtype=torch.long)
173+
offsets = torch.cat(
174+
(
175+
torch.tensor([0], device=grid_thw.device, dtype=torch.long),
176+
torch.cumsum(patch_counts_tensor, dim=0),
177+
)
178+
)
143179

144-
local_patches = []
145-
local_grids = []
146-
for idx in local_indices:
147-
start, end = cumsum[idx], cumsum[idx + 1]
148-
local_patches.append(pixel_values[start:end])
149-
local_grids.append(grid_thw[idx : idx + 1])
180+
start_patch = offsets[first_img_idx].item()
181+
end_patch = offsets[last_img_idx + 1].item()
150182

151-
local_pixel_values = torch.cat(local_patches, dim=0)
152-
local_grid_thw = torch.cat(local_grids, dim=0)
183+
local_pixel_values = pixel_values[start_patch:end_patch]
184+
local_grid_thw = grid_thw[first_img_idx : last_img_idx + 1]
153185

154-
expected_patches = sum(patch_counts[idx] for idx in local_indices)
155-
assert local_pixel_values.shape[0] == expected_patches
186+
expected_patches = end_patch - start_patch
187+
assert local_pixel_values.shape[0] == expected_patches, (
188+
f"[Vision DP] Local patch count mismatch: "
189+
f"extracted={local_pixel_values.shape[0]}, expected={expected_patches}, "
190+
f"local_indices={local_indices}"
191+
)
156192

157193
return local_pixel_values, local_grid_thw, local_indices
158194

@@ -161,28 +197,22 @@ class GatherVisionEmbeddings(Function):
161197
"""All-gather vision embeddings with gradient support.
162198
163199
Contiguous assignment means simple concat without reordering.
164-
Backward: scales gradients by dp_size to compensate for partial processing.
200+
Backward: all_reduce(SUM) to aggregate gradients from all sequence shards,
201+
then slice to extract this rank's image gradients.
165202
"""
166203

167204
@staticmethod
168-
def forward(ctx, local_embeddings, dp_group, grad_scaler=True):
169-
ctx.grad_scaler = grad_scaler
205+
def forward(ctx, local_embeddings, dp_group, all_counts: list[int]):
170206
dp_size = dist.get_world_size(dp_group)
171207
dp_rank = dist.get_rank(dp_group)
172208
ctx.dp_size = dp_size
209+
ctx.dp_group = dp_group
210+
ctx.all_counts = all_counts
211+
ctx.dp_rank = dp_rank
173212

174213
if dp_size == 1:
175214
return local_embeddings
176215

177-
local_count = torch.tensor(
178-
[local_embeddings.shape[0]], dtype=torch.long, device=local_embeddings.device
179-
)
180-
all_counts = [torch.zeros_like(local_count) for _ in range(dp_size)]
181-
dist.all_gather(all_counts, local_count, group=dp_group)
182-
all_counts = [c.item() for c in all_counts]
183-
ctx.all_counts = all_counts
184-
ctx.dp_rank = dp_rank
185-
186216
max_count = max(all_counts) if all_counts else 0
187217
if max_count == 0:
188218
return local_embeddings
@@ -211,38 +241,41 @@ def forward(ctx, local_embeddings, dp_group, grad_scaler=True):
211241
@staticmethod
212242
def backward(ctx, grad_output):
213243
dp_size = ctx.dp_size
214-
grad_scaler = ctx.grad_scaler
215244

216245
if dp_size == 1:
217246
return grad_output, None, None
218247

219248
all_counts = ctx.all_counts
220249
dp_rank = ctx.dp_rank
250+
dp_group = ctx.dp_group
221251

222-
if grad_scaler:
223-
grad_output = grad_output * dp_size
252+
# Aggregate gradient contributions from all SP ranks.
253+
# Each rank only has non-zero grad for vision tokens in its own
254+
# sequence shard. Summing across ranks recovers the complete
255+
# gradient for every image before we slice by image assignment.
256+
dist.all_reduce(grad_output, op=dist.ReduceOp.SUM, group=dp_group)
224257

225258
start = sum(all_counts[:dp_rank])
226259
end = start + all_counts[dp_rank]
227260
local_grad = grad_output[start:end]
228261
return local_grad, None, None
229262

230263

231-
def gather_vision_embeddings(local_embeddings, dp_group=None, grad_scaler=True):
264+
def gather_vision_embeddings(local_embeddings, dp_group, all_counts: list[int]):
232265
"""All-gather vision embeddings from all DP ranks.
233266
234267
Args:
235268
local_embeddings: This rank's vision embeddings.
236269
dp_group: Process group for all-gather. Defaults to Ulysses group.
237-
grad_scaler: Whether to scale gradients in backward pass.
270+
all_counts: Pre-computed embedding counts per rank (avoids an all_gather).
238271
239272
Returns:
240273
All-gathered embeddings concatenated across ranks.
241274
"""
242275
dp_group = get_ulysses_group() if dp_group is None else dp_group
243276
if dp_group is None or dist.get_world_size(dp_group) == 1:
244277
return local_embeddings
245-
return GatherVisionEmbeddings.apply(local_embeddings, dp_group, grad_scaler)
278+
return GatherVisionEmbeddings.apply(local_embeddings, dp_group, all_counts)
246279

247280

248281
def create_dp_vision_forward(original_forward):
@@ -269,8 +302,12 @@ def dp_vision_forward(self, hidden_states, grid_thw, **kwargs):
269302
dp_group = get_ulysses_group()
270303
dp_rank = dist.get_rank(dp_group)
271304

305+
# Move grid_thw to CPU once to avoid repeated GPU->CPU syncs in
306+
# metadata helpers (grid_thw is a tiny [num_images, 3] tensor).
307+
grid_thw_cpu = grid_thw.cpu()
308+
272309
# Step 1: Get image assignment
273-
patch_counts = get_image_patch_counts(grid_thw)
310+
patch_counts = get_image_patch_counts(grid_thw_cpu)
274311
total_patches = sum(patch_counts)
275312
assert hidden_states.shape[0] == total_patches
276313

@@ -280,10 +317,10 @@ def dp_vision_forward(self, hidden_states, grid_thw, **kwargs):
280317
elif hasattr(self, "spatial_merge_size"):
281318
spatial_merge_size = self.spatial_merge_size
282319

283-
embedding_counts = get_image_embedding_counts(grid_thw, spatial_merge_size)
320+
embedding_counts = get_image_embedding_counts(grid_thw_cpu, spatial_merge_size)
284321
total_embeddings = sum(embedding_counts)
285322

286-
image_assignments, rank_loads = assign_images_to_dp_ranks(patch_counts, dp_size)
323+
image_assignments, _ = assign_images_to_dp_ranks(patch_counts, dp_size)
287324

288325
# Step 2: Extract local inputs
289326
local_pixels, local_grid_thw, local_indices = prepare_local_vision_inputs(
@@ -328,7 +365,9 @@ def dp_vision_forward(self, hidden_states, grid_thw, **kwargs):
328365
local_embeddings, deepstack_outputs = local_embeddings[0], local_embeddings[1:]
329366

330367
# Step 4: All-gather
331-
all_embeddings = gather_vision_embeddings(local_embeddings, dp_group)
368+
# Compute per-rank embedding counts locally (grid_thw is replicated on all ranks)
369+
all_counts = [sum(embedding_counts[i] for i in image_assignments[r]) for r in range(dp_size)]
370+
all_embeddings = gather_vision_embeddings(local_embeddings, dp_group, all_counts)
332371
assert all_embeddings.shape[0] == total_embeddings
333372

334373
if deepstack_outputs is not None:
@@ -339,10 +378,10 @@ def dp_vision_forward(self, hidden_states, grid_thw, **kwargs):
339378
# List of tensors (one per deepstack layer)
340379
gathered_list = []
341380
for single_emb in ds_emb:
342-
gathered_list.append(gather_vision_embeddings(single_emb, dp_group))
381+
gathered_list.append(gather_vision_embeddings(single_emb, dp_group, all_counts))
343382
gathered_deepstack.append(gathered_list)
344383
elif isinstance(ds_emb, torch.Tensor):
345-
gathered_deepstack.append(gather_vision_embeddings(ds_emb, dp_group))
384+
gathered_deepstack.append(gather_vision_embeddings(ds_emb, dp_group, all_counts))
346385
else:
347386
gathered_deepstack.append(ds_emb)
348387
return (all_embeddings, *gathered_deepstack)

0 commit comments

Comments (0)