@@ -23,10 +23,11 @@
 import torch
 from accelerate import init_empty_weights
 from nemo_automodel import (
-    NeMoAutoModelForCausalLM,
     NeMoAutoModelForSequenceClassification,
 )
-from nemo_automodel.components._transformers.utils import sliding_window_overwrite
+from nemo_automodel.components._transformers.utils import (
+    sliding_window_overwrite,
+)
 from nemo_automodel.components.distributed.cp_utils import (
     create_context_parallel_ctx,
     get_train_context,
@@ -56,6 +57,7 @@
 from torch.distributed.tensor import DTensor, Shard
 from transformers import (
     AutoConfig,
+    AutoProcessor,
     AutoTokenizer,
 )
 from transformers.models.gemma3.modeling_gemma3 import Gemma3ForCausalLM
@@ -79,6 +81,7 @@
     get_handle_from_tensor,
     get_runtime_env_for_policy_worker,
     import_class_from_path,
+    resolve_model_class,
 )
 from nemo_rl.utils.native_checkpoint import (
     load_checkpoint,
@@ -105,12 +108,19 @@ def __init__(
         self,
         config: PolicyConfig,
         tokenizer: AutoTokenizer,
+        processor: Optional[AutoProcessor] = None,
         weights_path: Optional[str] = None,
         optimizer_path: Optional[str] = None,
         init_optimizer: bool = True,
         init_reference_model: bool = True,
         **kwargs: Any,
     ):
+        self.tokenizer = tokenizer
+        self.processor = processor
+        self.is_vlm = processor is not None
+
+        print(f"Initializing DTensorPolicyWorkerV2 with is_vlm={self.is_vlm}")
+
         self.is_generation_colocated = None
         if "generation" in config and config["generation"] is not None:
            self.is_generation_colocated = config["generation"]["colocated"]["enabled"]
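For context, a minimal sketch of how a caller might decide whether to hand the worker a processor, which is what flips `is_vlm` above. The model id and the `vision_config` heuristic are illustrative assumptions, not taken from this commit:

```python
# Illustrative only: the model id and the vision_config heuristic are assumptions.
from transformers import AutoConfig, AutoProcessor, AutoTokenizer

model_name = "Qwen/Qwen2.5-VL-3B-Instruct"  # placeholder multimodal checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Pass a processor only for multimodal checkpoints; text-only models keep processor=None.
config = AutoConfig.from_pretrained(model_name)
processor = (
    AutoProcessor.from_pretrained(model_name)
    if getattr(config, "vision_config", None) is not None
    else None
)

# worker = DTensorPolicyWorkerV2(policy_cfg, tokenizer, processor=processor)
```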
@@ -146,6 +156,9 @@ def __init__(
         print(f"[Rank {self.rank}] Loading model {model_name} on CPU...")
         self.enable_seq_packing = self.cfg["sequence_packing"]["enabled"]
         if self.enable_seq_packing:
+            assert not self.is_vlm, (
+                "Sequence packing is not supported for VLM models. Please set policy.sequence_packing.enabled = False to train VLM models."
+            )
             print(
                 f"[Rank {self.rank}] Sequence packing is enabled for model {model_name}"
             )
@@ -195,7 +208,8 @@ def __init__(
             else:
                 raise ValueError(f"Unknown reward model type: {rm_type}")
         else:
-            model_class = NeMoAutoModelForCausalLM
+            # DO NOT assume AutoModelForCausalLM, multimodal models can inherit from AutoModelForImageTextToText, AutoModelForTextToWaveform, etc.
+            model_class = resolve_model_class(model_config.model_type)
 
         full_state_dict = None
         if self.rank == 0:
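A rough sketch of what a model-type resolver along these lines could look like; the actual `resolve_model_class` in nemo_rl may be implemented differently, and the registry lookup below is an assumption based on the Hugging Face Auto-class mappings:

```python
# Hypothetical sketch; the real resolve_model_class may differ.
from transformers import AutoModelForCausalLM, AutoModelForImageTextToText
from transformers.models.auto.modeling_auto import (
    MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES,
)


def resolve_model_class_sketch(model_type: str):
    """Return the Auto class whose registry claims this model_type."""
    if model_type in MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES:
        return AutoModelForImageTextToText  # e.g. gemma3, qwen2_5_vl, llava
    # Other modalities (e.g. text-to-waveform) could be checked the same way.
    return AutoModelForCausalLM
```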
@@ -205,6 +219,7 @@ def __init__(
                 device_map="cpu",  # load weights onto CPU initially
                 trust_remote_code=True,
                 config=model_config,
+                torch_dtype=str(model_config.torch_dtype),
             )
 
             full_state_dict = model.state_dict()
@@ -224,19 +239,12 @@ def __init__(
                 if self.enable_seq_packing
                 else None,
                 trust_remote_code=True,
+                torch_dtype=str(model_config.torch_dtype),
             )
 
         if self.model.config.pad_token_id is None:
             self.model.config.pad_token_id = tokenizer.pad_token_id
 
-        # caching since this property is not always preserved after FSDP
-        self.tokenizer = tokenizer
-
-        # ------------------------------------------------
-        # 3) Move to GPU + Composable FSDP
-        #    (Initialize device mesh, shard submodules, then shard entire model)
-        # ------------------------------------------------
-
         tp_size = self.cfg["dtensor_cfg"]["tensor_parallel_size"]
         cp_size = self.cfg["dtensor_cfg"]["context_parallel_size"]
         if cp_size > 1 and self.enable_seq_packing:
@@ -266,6 +274,10 @@ def __init__(
266274 "See https://github.com/NVIDIA-NeMo/RL/issues/659 for more details."
267275 )
268276
277+ assert not self .is_vlm , (
278+ "Context parallel is yet not supported for VLM models. Please set cp_size = 1 to train VLM models."
279+ )
280+
269281 # For FSDP2 compatibility, we need to support HSDP structure
270282 # For now, we use dp_replicate_size = 1 (no hybrid sharding)
271283 dp_replicate_size = 1
@@ -299,6 +311,10 @@ def __init__(
         self.cp_size = cp_size
         self.device_mesh = device_mesh
 
+        # ------------------------------------------------
+        # 3) Move to GPU + Composable FSDP
+        #    (Initialize device mesh, shard submodules, then shard entire model)
+        # ------------------------------------------------
         self.model = fsdp2_strategy_parallelize(
             self.model,
             device_mesh=self.device_mesh,
@@ -597,8 +613,18 @@ def train(
                     ).repeat(batch_size, 1)
                     flash_attn_kwargs = {}
 
+                # add vlm kwargs to model call
+                vlm_kwargs = mb.get_multimodal_dict(
+                    as_tensors=True, device=input_ids.device
+                )
+                if len(vlm_kwargs) > 0:
+                    position_ids = None
+
                 context_parallel_ctx = None
                 if self.cp_size > 1:
+                    assert len(vlm_kwargs) == 0, (
+                        f"multimodal kwargs={vlm_kwargs} are not supported for context parallel"
+                    )
                     seq_index = torch.arange(
                         seq_len, device=input_ids.device
                     ).repeat(1, 1)
@@ -624,6 +650,7 @@ def train(
                             position_ids=position_ids,
                             use_cache=False,
                             flash_attn_kwargs=flash_attn_kwargs,
+                            **vlm_kwargs,
                         )
 
                         if self._is_reward_model:
@@ -632,6 +659,9 @@ def train(
                             # is not supported for reward models.
                             assert not flash_attn_kwargs
                             del model_args["flash_attn_kwargs"]
+                        # remove flash_attn_kwargs if there are multimodal kwargs
+                        if len(vlm_kwargs) > 0:
+                            del model_args["flash_attn_kwargs"]
 
                         outputs = self.model(**model_args)
 
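The multimodal handling above follows one convention in both the train and logprob paths: when multimodal tensors are present, the explicit position_ids are dropped (the model derives them internally) and the text-only flash-attention kwargs are removed before the forward call. A small hypothetical helper, not part of this commit, that mirrors that convention:

```python
# Hypothetical helper mirroring the multimodal convention in train()/get_logprobs();
# not part of this commit.
from typing import Any, Dict, Optional

import torch


def build_forward_kwargs(
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    position_ids: Optional[torch.Tensor],
    flash_attn_kwargs: Dict[str, Any],
    vlm_kwargs: Dict[str, Any],
) -> Dict[str, Any]:
    args: Dict[str, Any] = dict(
        input_ids=input_ids,
        attention_mask=attention_mask,
        # Multimodal models derive position ids from the image/text layout,
        # so explicit ones are only passed for text-only batches.
        position_ids=None if vlm_kwargs else position_ids,
        use_cache=False,
        **vlm_kwargs,  # e.g. pixel_values plus model-specific layout tensors
    )
    if not vlm_kwargs:
        # The custom flash-attention path assumes text-only inputs.
        args["flash_attn_kwargs"] = flash_attn_kwargs
    return args
```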
@@ -859,9 +889,15 @@ def get_logprobs(
                 step += 1
                 input_ids = lp_batch.get("input_ids").cuda()
                 input_lengths = lp_batch.get("input_lengths")
+                vlm_kwargs = lp_batch.get_multimodal_dict(
+                    as_tensors=True, device=input_ids.device
+                )
 
                 batch_size, seq_len = input_ids.shape
                 if self.enable_seq_packing:
+                    assert len(vlm_kwargs) == 0, (
+                        "multimodal kwargs are not supported for sequence packing"
+                    )
                     input_ids, position_ids, _ = pack_sequences(
                         input_ids=input_ids,
                         input_lengths=input_lengths,
@@ -901,8 +937,15 @@ def get_logprobs(
                         (batch_size, seq_len), dtype=torch.long, device=input_ids.device
                     )
 
+                # if there are multimodal kwargs, we don't need to add position_ids (computed internally)
+                if len(vlm_kwargs) > 0:
+                    position_ids = None
+
                 context_parallel_ctx = None
                 if self.cp_size > 1:
+                    assert len(vlm_kwargs) == 0, (
+                        "multimodal kwargs are not supported for context parallel"
+                    )
                     seq_index = torch.arange(seq_len, device=input_ids.device).repeat(
                         1, 1
                     )
@@ -918,13 +961,18 @@ def get_logprobs(
 
                 with get_train_context(False, False, context_parallel_ctx)():
                     with torch.autocast(device_type="cuda", dtype=self.dtype):
-                        outputs = self.model(
+                        model_args = dict(
                             input_ids=input_ids,
                             attention_mask=attention_mask_input_all_ones,
                             position_ids=position_ids,
                             use_cache=False,
                             flash_attn_kwargs=flash_attn_kwargs,
+                            **vlm_kwargs,
                         )
+                        if len(vlm_kwargs) > 0:
+                            del model_args["flash_attn_kwargs"]
+
+                        outputs = self.model(**model_args)
 
                         logits = outputs.logits
 