1313else :
1414 old_update_causal_mask = None
1515
# Module-level registry mapping a model key (e.g. "qwen2_vl") to the unpatched
# VisionTransformer.forward saved by apply_vision_dp_patch(), so that
# unapply_vision_dp_patch() can restore the class to its original state.
_original_vision_forwards = {}
18+
1619
1720def apply_ulysses_patch ():
1821 from .ulysses_attention import _flash_attention_forward , _update_causal_mask
@@ -35,6 +38,100 @@ def apply_ulysses_patch():
3538 return patch_info
3639
3740
def apply_vision_dp_patch():
    """Patch VisionTransformer.forward for Vision Data Parallel.

    Distributes whole images across Ulysses SP ranks for parallelized ViT computation.
    Each rank processes 1/sp_size of images, then all-gathers embeddings.

    This reduces ViT peak memory by ~sp_size x (e.g. SP=4 -> ~4x reduction).

    Idempotent: calling it again is a no-op for classes already patched, so the
    saved original forward is never clobbered by an already-wrapped one.
    """
    import importlib

    from .vision_dp import create_dp_vision_forward

    # (registry key, human-readable label, module path, vision class name) for
    # every supported Qwen VL family. Keeping this table local keeps the patch
    # and unpatch paths independent of each other.
    targets = (
        ("qwen2_vl", "Qwen2-VL",
         "transformers.models.qwen2_vl.modeling_qwen2_vl",
         "Qwen2VisionTransformerPretrainedModel"),
        ("qwen2_5_vl", "Qwen2.5-VL",
         "transformers.models.qwen2_5_vl.modeling_qwen2_5_vl",
         "Qwen2_5_VisionTransformerPretrainedModel"),
        ("qwen3_vl", "Qwen3-VL",
         "transformers.models.qwen3_vl.modeling_qwen3_vl",
         "Qwen3VLVisionModel"),
        ("qwen3_vl_moe", "Qwen3-VL-MoE",
         "transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe",
         "Qwen3VLMoeVisionModel"),
    )

    for key, label, module_path, class_name in targets:
        if key in _original_vision_forwards:
            # Already patched — re-wrapping would both double-apply the DP
            # forward and overwrite the saved pristine original.
            continue
        try:
            vision_cls = getattr(importlib.import_module(module_path), class_name)
        except (ImportError, AttributeError) as e:
            # Model family not shipped by this transformers version; skip quietly.
            logger.debug(f"{label} not available for Vision DP patch: {e}")
            continue
        original = vision_cls.forward
        _original_vision_forwards[key] = original
        vision_cls.forward = create_dp_vision_forward(original)
        logger.info(f"Monkey patch {class_name}.forward for Vision DP")
97+
def unapply_vision_dp_patch():
    """Restore original VisionTransformer.forward methods.

    Reverses apply_vision_dp_patch() for every model family recorded in
    _original_vision_forwards. A registry entry is removed only once its class
    was successfully re-imported and restored; if the import fails the entry is
    left in place (matching the previous per-model try/except behavior).
    Safe to call when nothing was patched.
    """
    import importlib

    # Must mirror the (key, module path, class name) table used when patching.
    targets = (
        ("qwen2_vl",
         "transformers.models.qwen2_vl.modeling_qwen2_vl",
         "Qwen2VisionTransformerPretrainedModel"),
        ("qwen2_5_vl",
         "transformers.models.qwen2_5_vl.modeling_qwen2_5_vl",
         "Qwen2_5_VisionTransformerPretrainedModel"),
        ("qwen3_vl",
         "transformers.models.qwen3_vl.modeling_qwen3_vl",
         "Qwen3VLVisionModel"),
        ("qwen3_vl_moe",
         "transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe",
         "Qwen3VLMoeVisionModel"),
    )

    for key, module_path, class_name in targets:
        if key not in _original_vision_forwards:
            continue
        try:
            vision_cls = getattr(importlib.import_module(module_path), class_name)
        except (ImportError, AttributeError):
            # Nothing importable to restore; keep the saved entry untouched.
            continue
        vision_cls.forward = _original_vision_forwards.pop(key)
133+
134+
38135def unapply_ulysses_patch ():
39136 global old_flash_attention_forward , old_update_causal_mask
40137 ALL_ATTENTION_FUNCTIONS ["flash_attention_2" ] = old_flash_attention_forward
@@ -47,3 +144,4 @@ def unapply_ulysses_patch():
47144 unapply_hf_flash_attention_ulysses_patch ()
48145 except Exception :
49146 pass
147+ unapply_vision_dp_patch ()
0 commit comments