Skip to content

Commit 9aef79d

Browse files
RobotSail and claude
committed
fix VLM OSFT support for direct-loaded models
- Remove unnecessary OSFT guard for direct VLMs (patterns match fine) - Add _get_text_config() helper for VLM config fallback in align_model_and_tokenizer (vocab_size, pad/bos/eos_token_id) - Fix model.config.pad_token_id access in train.py for VLM configs - Skip activation checkpointing for direct VLM models (M-RoPE layers produce non-deterministic tensor counts during reentrant recomputation) - Use dynamic ports (_get_free_port) in model_validation.py to prevent port conflicts between sequential tests Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent b16b479 commit 9aef79d

File tree

2 files changed

+32
-25
lines changed

2 files changed

+32
-25
lines changed

src/mini_trainer/setup_model_for_training.py

Lines changed: 31 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -449,11 +449,21 @@ def wrap_fsdp2(model: torch.nn.Module) -> torch.nn.Module:
449449
"This likely means we need to update the code to support this model."
450450
)
451451

452-
# Apply activation checkpointing to each block
453-
log_rank_0(f"🔄 [Phase 2] Applying activation checkpointing to {len(layers)} blocks")
454-
for idx, block in enumerate(layers):
455-
# preserve_rng_state needs to be true so that the backward pass can be accurate
456-
layers[idx] = ptd_checkpoint_wrapper(block, preserve_rng_state=True)
452+
# Apply activation checkpointing to each block.
453+
# VLM models (detected by language_model path) may have non-deterministic
454+
# tensor counts during reentrant recomputation (e.g., M-RoPE), so we skip
455+
# activation checkpointing for them. This uses more memory but avoids
456+
# CheckpointError from tensor count mismatches.
457+
is_vlm_direct = hasattr(model, "model") and hasattr(model.model, "language_model")
458+
if is_vlm_direct:
459+
log_rank_0(
460+
f"🔄 [Phase 2] Skipping activation checkpointing for VLM ({len(layers)} blocks)"
461+
)
462+
else:
463+
log_rank_0(f"🔄 [Phase 2] Applying activation checkpointing to {len(layers)} blocks")
464+
for idx, block in enumerate(layers):
465+
# preserve_rng_state needs to be true so that the backward pass can be accurate
466+
layers[idx] = ptd_checkpoint_wrapper(block, preserve_rng_state=True)
457467

458468
# Build 1D device mesh over all ranks
459469
world_size = dist.get_world_size()
@@ -577,12 +587,21 @@ def finalize_model_initialization(model: torch.nn.Module, context: ModelInitiali
577587
return model
578588

579589

590+
def _get_text_config(model):
591+
"""Get the text-relevant config, falling back to text_config for VLMs."""
592+
config = model.config
593+
if not hasattr(config, "vocab_size") and hasattr(config, "text_config"):
594+
return config.text_config
595+
return config
596+
597+
580598
def align_model_and_tokenizer(model, tokenizer):
581599
"""
582600
Aligns the model's vocabulary and special tokens with the tokenizer.
583601
"""
584-
if len(tokenizer) > model.config.vocab_size:
585-
print(f"WARNING: tokenizer has {len(tokenizer)} tokens but model has {model.config.vocab_size} vocab size")
602+
text_config = _get_text_config(model)
603+
if len(tokenizer) > text_config.vocab_size:
604+
print(f"WARNING: tokenizer has {len(tokenizer)} tokens but model has {text_config.vocab_size} vocab size")
586605
model.resize_token_embeddings(
587606
int(8 * math.ceil(len(tokenizer) / 8.0))
588607
) # make the vocab size multiple of 8 for sharding the embedding layer.
@@ -603,22 +622,22 @@ def align_model_and_tokenizer(model, tokenizer):
603622
"Cannot proceed with training - please configure the tokenizer properly."
604623
)
605624

606-
# Step 2: Sync all special tokens from tokenizer to model.config
607-
# This ensures model.config always reflects tokenizer's special tokens
625+
# Step 2: Sync all special tokens from tokenizer to text_config
626+
# This ensures the config always reflects tokenizer's special tokens
608627
special_tokens = {
609628
"pad": ("pad_token_id", "Syncing model pad token id"),
610629
"bos": ("bos_token_id", "Syncing model bos token id"),
611630
"eos": ("eos_token_id", "Syncing model eos token id"),
612631
}
613632

614633
for token_type, (token_attr, message) in special_tokens.items():
615-
model_token = getattr(model.config, token_attr)
634+
model_token = getattr(text_config, token_attr, None)
616635
tokenizer_token = getattr(tokenizer, token_attr)
617636

618-
# Always sync tokenizer -> model.config when tokenizer has a valid value
637+
# Always sync tokenizer -> config when tokenizer has a valid value
619638
if tokenizer_token is not None and model_token != tokenizer_token:
620639
log_rank_0(f"{message}: {model_token} -> {tokenizer_token}")
621-
setattr(model.config, token_attr, tokenizer_token)
640+
setattr(text_config, token_attr, tokenizer_token)
622641

623642
return model
624643

@@ -1042,18 +1061,6 @@ def load_standard_model():
10421061

10431062
def load_osft_model():
10441063
"""Load a model with OSFT (Orthogonal Subspace Fine-Tuning) support."""
1045-
# Direct VLMs (no CausalLM class) are not supported for OSFT yet.
1046-
# OSFT wraps the base model class and modifies internal weights, which
1047-
# requires a CausalLM-compatible architecture. Direct VLMs have a
1048-
# different layer structure (model.model.language_model.layers) that
1049-
# would need significant OSFT adapter changes.
1050-
if is_vlm_for_direct_loading(model_config):
1051-
raise ValueError(
1052-
f"OSFT is not supported for direct VLM models (e.g. {model_name_or_path}). "
1053-
"This model has no standalone CausalLM class and cannot be wrapped by OSFT. "
1054-
"Use SFT training instead (osft=False)."
1055-
)
1056-
10571064
log_rank_0("loading OSFT model")
10581065
# If osft_output_dtype is not specified, use train_dtype for consistency
10591066
effective_osft_output_dtype = osft_output_dtype if osft_output_dtype is not None else train_dtype

src/mini_trainer/train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1239,7 +1239,7 @@ def main(
12391239
batch_size=batch_size,
12401240
max_tokens_per_gpu=max_tokens_per_gpu,
12411241
seed=seed,
1242-
pad_token_id=model.config.pad_token_id,
1242+
pad_token_id=getattr(getattr(model.config, "text_config", model.config), "pad_token_id", None),
12431243
validation_split=validation_split,
12441244
pretraining_config=pretraining_config,
12451245
)

0 commit comments

Comments (0)