@@ -449,11 +449,21 @@ def wrap_fsdp2(model: torch.nn.Module) -> torch.nn.Module:
449449 "This likely means we need to update the code to support this model."
450450 )
451451
452- # Apply activation checkpointing to each block
453- log_rank_0 (f"🔄 [Phase 2] Applying activation checkpointing to { len (layers )} blocks" )
454- for idx , block in enumerate (layers ):
455- # preserve_rng_state needs to be true so that the backward pass can be accurate
456- layers [idx ] = ptd_checkpoint_wrapper (block , preserve_rng_state = True )
452+ # Apply activation checkpointing to each block.
453+ # VLM models (detected by language_model path) may have non-deterministic
454+ # tensor counts during reentrant recomputation (e.g., M-RoPE), so we skip
455+ # activation checkpointing for them. This uses more memory but avoids
456+ # CheckpointError from tensor count mismatches.
457+ is_vlm_direct = hasattr (model , "model" ) and hasattr (model .model , "language_model" )
458+ if is_vlm_direct :
459+ log_rank_0 (
460+ f"🔄 [Phase 2] Skipping activation checkpointing for VLM ({ len (layers )} blocks)"
461+ )
462+ else :
463+ log_rank_0 (f"🔄 [Phase 2] Applying activation checkpointing to { len (layers )} blocks" )
464+ for idx , block in enumerate (layers ):
465+ # preserve_rng_state needs to be true so that the backward pass can be accurate
466+ layers [idx ] = ptd_checkpoint_wrapper (block , preserve_rng_state = True )
457467
458468 # Build 1D device mesh over all ranks
459469 world_size = dist .get_world_size ()
@@ -577,12 +587,21 @@ def finalize_model_initialization(model: torch.nn.Module, context: ModelInitiali
577587 return model
578588
579589
590+ def _get_text_config (model ):
591+ """Get the text-relevant config, falling back to text_config for VLMs."""
592+ config = model .config
593+ if not hasattr (config , "vocab_size" ) and hasattr (config , "text_config" ):
594+ return config .text_config
595+ return config
596+
597+
580598def align_model_and_tokenizer (model , tokenizer ):
581599 """
582600 Aligns the model's vocabulary and special tokens with the tokenizer.
583601 """
584- if len (tokenizer ) > model .config .vocab_size :
585- print (f"WARNING: tokenizer has { len (tokenizer )} tokens but model has { model .config .vocab_size } vocab size" )
602+ text_config = _get_text_config (model )
603+ if len (tokenizer ) > text_config .vocab_size :
604+ print (f"WARNING: tokenizer has { len (tokenizer )} tokens but model has { text_config .vocab_size } vocab size" )
586605 model .resize_token_embeddings (
587606 int (8 * math .ceil (len (tokenizer ) / 8.0 ))
588607 ) # make the vocab size multiple of 8 for sharding the embedding layer.
@@ -603,22 +622,22 @@ def align_model_and_tokenizer(model, tokenizer):
603622 "Cannot proceed with training - please configure the tokenizer properly."
604623 )
605624
606- # Step 2: Sync all special tokens from tokenizer to model.config
607- # This ensures model. config always reflects tokenizer's special tokens
625+ # Step 2: Sync all special tokens from tokenizer to text_config
626+ # This ensures the config always reflects tokenizer's special tokens
608627 special_tokens = {
609628 "pad" : ("pad_token_id" , "Syncing model pad token id" ),
610629 "bos" : ("bos_token_id" , "Syncing model bos token id" ),
611630 "eos" : ("eos_token_id" , "Syncing model eos token id" ),
612631 }
613632
614633 for token_type , (token_attr , message ) in special_tokens .items ():
615- model_token = getattr (model . config , token_attr )
634+ model_token = getattr (text_config , token_attr , None )
616635 tokenizer_token = getattr (tokenizer , token_attr )
617636
618- # Always sync tokenizer -> model. config when tokenizer has a valid value
637+ # Always sync tokenizer -> config when tokenizer has a valid value
619638 if tokenizer_token is not None and model_token != tokenizer_token :
620639 log_rank_0 (f"{ message } : { model_token } -> { tokenizer_token } " )
621- setattr (model . config , token_attr , tokenizer_token )
640+ setattr (text_config , token_attr , tokenizer_token )
622641
623642 return model
624643
@@ -1042,18 +1061,6 @@ def load_standard_model():
10421061
10431062 def load_osft_model ():
10441063 """Load a model with OSFT (Orthogonal Subspace Fine-Tuning) support."""
1045- # Direct VLMs (no CausalLM class) are not supported for OSFT yet.
1046- # OSFT wraps the base model class and modifies internal weights, which
1047- # requires a CausalLM-compatible architecture. Direct VLMs have a
1048- # different layer structure (model.model.language_model.layers) that
1049- # would need significant OSFT adapter changes.
1050- if is_vlm_for_direct_loading (model_config ):
1051- raise ValueError (
1052- f"OSFT is not supported for direct VLM models (e.g. { model_name_or_path } ). "
1053- "This model has no standalone CausalLM class and cannot be wrapped by OSFT. "
1054- "Use SFT training instead (osft=False)."
1055- )
1056-
10571064 log_rank_0 ("loading OSFT model" )
10581065 # If osft_output_dtype is not specified, use train_dtype for consistency
10591066 effective_osft_output_dtype = osft_output_dtype if osft_output_dtype is not None else train_dtype
0 commit comments