Skip to content

Commit c8eba5f

Browse files
aoshen524 authored and claude committed
feat(vision): add Vision DP for parallel ViT computation across Ulysses SP ranks
Distribute whole images across Ulysses SP ranks for parallelized ViT computation, reducing ViT peak memory by ~sp_size x (e.g. SP=4 -> ~4x ViT memory reduction).

Key changes:
- Add roll/utils/context_parallel/vision_dp.py with image distribution utilities, GatherVisionEmbeddings autograd function, and model-agnostic VisionTransformer wrapper
- Add apply_vision_dp_patch() in monkey_patch.py for Qwen2-VL, Qwen2.5-VL, Qwen3-VL, Qwen3-VL-MoE VisionTransformer classes
- Integrate into DeepSpeed strategy (both inference and training workers)
- Add 17 unit tests covering all utility functions, edge cases, and integration workflows

Ported from verl (verl-project/verl#5230).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent ae69fd8 commit c8eba5f

File tree

5 files changed

+703
-3
lines changed

5 files changed

+703
-3
lines changed

roll/distributed/strategy/deepspeed_strategy.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from roll.third_party.deepspeed.model_update import DeepSpeedWeightUpdater
2424
from roll.third_party.deepspeed.offload_states_patch import bind_deepspeed_offload_states_func
2525
from roll.utils.collective import collective
26-
from roll.utils.context_parallel import get_ulysses_group, set_upg_manager
26+
from roll.utils.context_parallel import apply_vision_dp_patch, get_ulysses_group, set_upg_manager
2727
from roll.utils.deepspeed_utils import get_optimizer_grouped_parameters
2828
from roll.utils.functionals import append_to_dict, entropy_from_logits, log_probs_from_logits
2929
from roll.utils.constants import IGNORE_INDEX
@@ -69,6 +69,7 @@ def initialize(self, model_provider):
6969
if (cp_size := self.worker_config.model_args.ulysses_size) > 1:
7070
if current_platform.apply_ulysses_patch() is not None:
7171
set_upg_manager(ulysses_size=cp_size, rank=global_rank, world_size=world_size)
72+
apply_vision_dp_patch()
7273
else:
7374
cp_size = 1
7475

@@ -332,6 +333,7 @@ def initialize(self, model_provider):
332333
if (cp_size := self.worker_config.model_args.ulysses_size) > 1:
333334
current_platform.apply_ulysses_patch()
334335
set_upg_manager(ulysses_size=cp_size, rank=global_rank, world_size=world_size)
336+
apply_vision_dp_patch()
335337

336338
self.worker.rank_info.dp_rank = global_rank // cp_size
337339
self.worker.rank_info.dp_size = world_size // cp_size
Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,17 @@
11
from roll.utils.context_parallel.globals import get_ulysses_group, set_upg_manager
2-
from roll.utils.context_parallel.monkey_patch import apply_ulysses_patch, unapply_ulysses_patch
2+
from roll.utils.context_parallel.monkey_patch import (
3+
apply_ulysses_patch,
4+
apply_vision_dp_patch,
5+
unapply_ulysses_patch,
6+
unapply_vision_dp_patch,
7+
)
38

4-
__all__ = ["set_upg_manager", "get_ulysses_group", "apply_ulysses_patch", "unapply_ulysses_patch"]
9+
10+
__all__ = [
11+
"set_upg_manager",
12+
"get_ulysses_group",
13+
"apply_ulysses_patch",
14+
"apply_vision_dp_patch",
15+
"unapply_ulysses_patch",
16+
"unapply_vision_dp_patch",
17+
]

roll/utils/context_parallel/monkey_patch.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
else:
1414
old_update_causal_mask = None
1515

16+
# Store original vision forwards for unapply
17+
_original_vision_forwards = {}
18+
1619

1720
def apply_ulysses_patch():
1821
from .ulysses_attention import _flash_attention_forward, _update_causal_mask
@@ -35,6 +38,100 @@ def apply_ulysses_patch():
3538
return patch_info
3639

3740

41+
def apply_vision_dp_patch():
    """Patch VisionTransformer.forward for Vision Data Parallel.

    Distributes whole images across Ulysses SP ranks for parallelized ViT
    computation. Each rank processes 1/sp_size of images, then all-gathers
    embeddings.

    This reduces ViT peak memory by ~sp_size x (e.g. SP=4 -> ~4x reduction).

    Idempotent: a class whose original forward is already recorded in
    ``_original_vision_forwards`` is skipped, so repeated calls (e.g. from
    both the inference and training initialize paths) never stack DP
    wrappers, and ``unapply_vision_dp_patch`` always restores the true
    original forward.
    """
    from .vision_dp import create_dp_vision_forward

    # registry key -> (module path, VisionTransformer class name, display name)
    targets = {
        "qwen2_vl": (
            "transformers.models.qwen2_vl.modeling_qwen2_vl",
            "Qwen2VisionTransformerPretrainedModel",
            "Qwen2-VL",
        ),
        "qwen2_5_vl": (
            "transformers.models.qwen2_5_vl.modeling_qwen2_5_vl",
            "Qwen2_5_VisionTransformerPretrainedModel",
            "Qwen2.5-VL",
        ),
        "qwen3_vl": (
            "transformers.models.qwen3_vl.modeling_qwen3_vl",
            "Qwen3VLVisionModel",
            "Qwen3-VL",
        ),
        "qwen3_vl_moe": (
            "transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe",
            "Qwen3VLMoeVisionModel",
            "Qwen3-VL-MoE",
        ),
    }

    for key, (module_path, class_name, display_name) in targets.items():
        if key in _original_vision_forwards:
            # Already patched: re-wrapping would record the patched forward
            # as "original" and stack wrappers. Skip.
            continue
        try:
            # __import__ with a fromlist returns the leaf submodule itself.
            module = __import__(module_path, fromlist=[class_name])
        except ImportError as e:
            logger.debug(f"{display_name} not available for Vision DP patch: {e}")
            continue
        vit_cls = getattr(module, class_name)
        _original_vision_forwards[key] = vit_cls.forward
        vit_cls.forward = create_dp_vision_forward(_original_vision_forwards[key])
        logger.info(f"Monkey patch {class_name}.forward for Vision DP")
96+
97+
98+
def unapply_vision_dp_patch():
    """Restore original VisionTransformer.forward methods.

    Only classes whose original forward was recorded by
    ``apply_vision_dp_patch`` are touched. Each saved forward is removed
    from the registry only once it has been put back on its class; if the
    model family cannot be imported, the entry is left in place.
    """
    # registry key -> (module path, VisionTransformer class name)
    targets = {
        "qwen2_vl": (
            "transformers.models.qwen2_vl.modeling_qwen2_vl",
            "Qwen2VisionTransformerPretrainedModel",
        ),
        "qwen2_5_vl": (
            "transformers.models.qwen2_5_vl.modeling_qwen2_5_vl",
            "Qwen2_5_VisionTransformerPretrainedModel",
        ),
        "qwen3_vl": (
            "transformers.models.qwen3_vl.modeling_qwen3_vl",
            "Qwen3VLVisionModel",
        ),
        "qwen3_vl_moe": (
            "transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe",
            "Qwen3VLMoeVisionModel",
        ),
    }

    for key, (module_path, class_name) in targets.items():
        if key not in _original_vision_forwards:
            continue
        try:
            # __import__ with a fromlist returns the leaf submodule itself.
            module = __import__(module_path, fromlist=[class_name])
        except ImportError:
            # Matches the original branch-per-model behavior: on import
            # failure the saved forward stays registered.
            continue
        setattr(getattr(module, class_name), "forward", _original_vision_forwards.pop(key))
133+
134+
38135
def unapply_ulysses_patch():
39136
global old_flash_attention_forward, old_update_causal_mask
40137
ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = old_flash_attention_forward
@@ -47,3 +144,4 @@ def unapply_ulysses_patch():
47144
unapply_hf_flash_attention_ulysses_patch()
48145
except Exception:
49146
pass
147+
unapply_vision_dp_patch()

0 commit comments

Comments
 (0)