 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import os
 from collections import defaultdict
-from typing import Any, Iterable, Optional
+from typing import Any, Optional

 import torch
 from torch.multiprocessing.reductions import rebuild_cuda_tensor
 )

-def _patch_gemma3_mm():
-    """Patch gemma3_mm.py to support new HF multimodal format (post transformers v4.52).
-
-    Patch taken from: https://github.com/vllm-project/vllm/pull/19151/files#diff-5890909300e4e6c3160444e4587ec3fd80498bb83f598b22ce81337f75992b06
-    """
-    from packaging.version import Version as PkgVersion
-
-    assert PkgVersion(vllm.__version__) < PkgVersion("0.9.2"), (
-        f"You are using vllm version {vllm.__version__}. "
-        "Please remove this patch (_patch_gemma3_mm in nemo_rl/models/generation/vllm_backend.py) "
-        "since it is included in vllm>=0.9.2."
-    )
-
-    from vllm.logger import init_logger
-    from vllm.model_executor.models import gemma3_mm
-    from vllm.model_executor.models.utils import (
-        AutoWeightsLoader,
-        WeightsMapper,
-    )
-
-    logger = init_logger("gemma3_mm_patch")
-
-    gemma3_mm.Gemma3ForConditionalGeneration.hf_to_vllm_mapper = WeightsMapper(
-        orig_to_new_prefix={
-            # mapping for new names in checkpoint saved after transformers v4.52
-            "model.language_model.": "language_model.model.",
-            "model.vision_tower.": "vision_tower.",
-            "model.multi_modal_projector.": "multi_modal_projector.",
-            "lm_head.": "language_model.lm_head.",
-        }
-    )
-
-    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
-        loader = AutoWeightsLoader(self)
-        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
-
-    gemma3_mm.Gemma3ForConditionalGeneration.load_weights = load_weights
-    logger.info("Successfully patched gemma3_mm.py in vllm_backend.")
-
-
-_patch_gemma3_mm()
-
-
 class VllmInternalWorkerExtension:
     def init_collective(
         self, rank_prefix: int, ip: str, port: int, world_size: int
@@ -82,10 +38,6 @@ def init_collective(
         local_rank = torch.distributed.get_rank()
         rank = rank_prefix + local_rank + 1  # 1 is the head node of the train cluster

-        # Temporary fix for vllm==0.9.0 which overrides the NCCL_CUMEM_ENABLE to 0 and causes
-        # https://github.com/NVIDIA-NeMo/RL/issues/564. This can be removed after it is upgraded to vllm>=0.9.1rc1.
-        os.environ["NCCL_CUMEM_ENABLE"] = "1"
-
         pg = StatelessProcessGroup.create(
             host=ip, port=port, rank=rank, world_size=world_size
         )
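
For reference, the HF-to-vLLM prefix remapping that the removed _patch_gemma3_mm patch installed can be sketched in plain Python. This is a minimal illustration under assumptions: it does not use vLLM's WeightsMapper class, and the checkpoint key names in the asserts are hypothetical examples rather than keys from a real Gemma3 checkpoint.

# Illustrative sketch only (not part of this diff): the prefix remapping that the
# removed patch registered, expressed as a standalone function.
ORIG_TO_NEW_PREFIX = {
    "model.language_model.": "language_model.model.",
    "model.vision_tower.": "vision_tower.",
    "model.multi_modal_projector.": "multi_modal_projector.",
    "lm_head.": "language_model.lm_head.",
}

def remap_key(name: str) -> str:
    # Rewrite a post-transformers-v4.52 checkpoint key to the vLLM module naming.
    for old, new in ORIG_TO_NEW_PREFIX.items():
        if name.startswith(old):
            return new + name[len(old):]
    return name

# Hypothetical example keys, for illustration only.
assert remap_key("model.vision_tower.encoder.layers.0.self_attn.q_proj.weight") == (
    "vision_tower.encoder.layers.0.self_attn.q_proj.weight"
)
assert remap_key("lm_head.weight") == "language_model.lm_head.weight"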