Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 37 additions & 6 deletions src/mini_trainer/osft_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
)
from mini_trainer.gpt_oss_utils import is_gpt_oss_model
from mini_trainer.utils import get_control_process_group, log_rank_0
from mini_trainer.vlm_utils import extract_causal_lm_from_vlm, is_vlm_with_causal_lm

# Memory optimization constants
OSFT_CACHE_CLEAR_INTERVAL = int(
Expand Down Expand Up @@ -315,6 +316,18 @@ def _reconstruct_weight(
# "experts.down_proj",
]
},
"qwen3_5": {
"patterns": [
"self_attn.q_proj",
"self_attn.k_proj",
"self_attn.v_proj",
"self_attn.o_proj",
"linear_attn.out_proj",
"mlp.gate_proj",
"mlp.down_proj",
"mlp.up_proj",
]
},
"default": {
"patterns": [
"self_attn.q_proj",
Expand Down Expand Up @@ -346,8 +359,8 @@ def _reconstruct_weight(
"mistral": "mistral",
"granite": "granite",
"gpt2": "gpt2",
# Easy to add more mappings
# "phi": "phi",
"qwen3_5": "qwen3_5",
"qwen3.5": "qwen3_5",
}


Expand Down Expand Up @@ -767,12 +780,26 @@ def _load_model_memory_efficient(
if dist.get_rank() == 0:
with torch.no_grad():
log_rank_0(f"📥 Loading base model to CPU in {load_dtype}...")
base_model = base_model_class.from_pretrained(
pretrained_model_name_or_path,
*model_args,
**final_base_kwargs,

# Check if this is a VLM wrapping a CausalLM text backbone
from transformers import AutoConfig

_pre_config = AutoConfig.from_pretrained(
pretrained_model_name_or_path, trust_remote_code=True
)

if is_vlm_with_causal_lm(_pre_config):
log_rank_0("🔄 VLM detected – extracting CausalLM text backbone for OSFT")
base_model = extract_causal_lm_from_vlm(
pretrained_model_name_or_path, final_base_kwargs
)
else:
base_model = base_model_class.from_pretrained(
pretrained_model_name_or_path,
*model_args,
**final_base_kwargs,
)

align_fn = osft_class_kwargs.get("lazy_init_tokenizer_align_fn")
if align_fn:
base_model = align_fn(base_model)
Expand Down Expand Up @@ -888,6 +915,10 @@ def create_osft_model_class(base_cls) -> type[OSFTModel]:
class ModelWithOSFT(base_cls):
osft_config: dict[str, int]

@classmethod
def _can_set_experts_implementation(cls):
return base_cls._can_set_experts_implementation()

def __init__(
self,
config,
Expand Down
78 changes: 72 additions & 6 deletions src/mini_trainer/setup_model_for_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@
log_rank_0,
patch_target_module,
)
from mini_trainer.vlm_utils import (
extract_causal_lm_from_vlm,
has_mrope,
is_vlm_with_causal_lm,
)


def _distributed_initialized() -> bool:
Expand Down Expand Up @@ -817,12 +822,20 @@ def setup_sft_model_distributed(
state_dict = None
buffer_dict = None

# Check if this is a VLM wrapping a CausalLM text backbone
_model_config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
is_vlm = is_vlm_with_causal_lm(_model_config)

if dist.get_rank() == 0:
log_rank_0("rank 0: loading model to CPU")
try:
with torch.no_grad():
# Default load targets CPU when no device_map or accelerate is present
cpu_model = ModelClass.from_pretrained(**base_model_args)
if is_vlm:
# VLM model: load full VLM and extract CausalLM text backbone
cpu_model = extract_causal_lm_from_vlm(model_name_or_path, base_model_args)
else:
# Standard CausalLM: load directly
cpu_model = ModelClass.from_pretrained(**base_model_args)
cpu_model = align_model_and_tokenizer(cpu_model, tokenizer)
config = cpu_model.config
state_dict = cpu_model.state_dict()
Expand Down Expand Up @@ -882,6 +895,32 @@ def setup_model(
model_config = AutoConfig.from_pretrained(model_name_or_path)
is_gpt_oss = is_gpt_oss_model(model_config)

# The Hub kernel for mamba-ssm is incompatible with causal_conv1d v1.6.0
# (different C API). Use local packages instead.
if getattr(model_config, "model_type", None) == "granitemoehybrid":
try:
from transformers.integrations.hub_kernels import _KERNEL_MODULE_MAPPING
import causal_conv1d
import mamba_ssm
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
from mamba_ssm.ops.triton.ssd_combined import (
mamba_chunk_scan_combined,
mamba_split_conv1d_scan_combined,
)

mamba_ssm.selective_state_update = selective_state_update
mamba_ssm.mamba_chunk_scan_combined = mamba_chunk_scan_combined
mamba_ssm.mamba_split_conv1d_scan_combined = mamba_split_conv1d_scan_combined

_KERNEL_MODULE_MAPPING["causal-conv1d"] = causal_conv1d
_KERNEL_MODULE_MAPPING["mamba-ssm"] = mamba_ssm
log_rank_0("Using local mamba_ssm/causal_conv1d instead of Hub kernels")
except ImportError:
log_rank_0(
"mamba_ssm or causal_conv1d not installed; "
"GraniteMoeHybrid will use slow (torch) path"
)

# Set up quantization config for GPT-OSS models
if is_gpt_oss:
try:
Expand All @@ -896,13 +935,35 @@ def setup_model(
except ImportError:
log_rank_0("⚠️ GPT-OSS model detected but Mxfp4Config not available - using default config")

# Check if model uses M-RoPE (multimodal rotary position embeddings).
# Models with M-RoPE (e.g. Qwen3.5) pass 3D position_ids through kwargs which
# causes Flash Attention 2's _is_packed_sequence() to misinterpret them as packed
# sequences, leading to incorrect cu_seqlens computation and CUDA illegal memory
# access. Force SDPA for these models.
_uses_mrope = has_mrope(model_config)

# Check if flash_attn is available and set appropriate attention implementation
try:
import flash_attn as _ # noqa: F401

if is_gpt_oss:
base_model_args["attn_implementation"] = "kernels-community/vllm-flash-attn3"
log_rank_0("Set attention implementation to vllm-flash-attn3 for GPT-OSS")
if _uses_mrope:
base_model_args["attn_implementation"] = "sdpa"
log_rank_0(
f"Using SDPA for {model_name_or_path} (M-RoPE model incompatible with Flash Attention 2 varlen path)"
)
elif is_gpt_oss:
# vllm-flash-attn3 requires Hopper (SM 9.0+) GPUs;
# GPT-OSS only supports flash-attn3 or eager
major, _ = torch.cuda.get_device_capability(0)
if major >= 9:
base_model_args["attn_implementation"] = "kernels-community/vllm-flash-attn3"
log_rank_0("Set attention implementation to vllm-flash-attn3 for GPT-OSS")
else:
base_model_args["attn_implementation"] = "eager"
log_rank_0(
f"GPT-OSS: flash-attn3 requires Hopper (SM 9.0+) GPUs, "
f"but found SM {major}.x. Using eager attention instead."
)
else:
base_model_args["attn_implementation"] = "flash_attention_2"

Expand Down Expand Up @@ -961,7 +1022,10 @@ def load_standard_model():
)
else:
# non-distributed path: direct loading
model = ModelClass.from_pretrained(**base_model_args)
if is_vlm_with_causal_lm(model_config):
model = extract_causal_lm_from_vlm(model_name_or_path, base_model_args)
else:
model = ModelClass.from_pretrained(**base_model_args)
return align_model_and_tokenizer(model, tokenizer)

def load_osft_model():
Expand Down Expand Up @@ -1070,6 +1134,7 @@ def load_osft_model():
# List of supported architectures
if class_name not in [
"MistralForCausalLM",
"Ministral3ForCausalLM",
"GPTDolomiteForCausalLM",
"LlamaForCausalLM",
"Starcoder2ForCausalLM",
Expand All @@ -1080,6 +1145,7 @@ def load_osft_model():
"Qwen2ForCausalLM",
"Phi3ForCausalLM", # covers phi3 and phi4
"Qwen3ForCausalLM",
"Qwen3_5ForCausalLM",
]:
log_rank_0(
f"\033[38;2;255;255;0mWarning: Model class name: {class_name} is not in the list of supported models.\033[0m",
Expand Down
20 changes: 16 additions & 4 deletions src/mini_trainer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,16 +139,28 @@ def init_distributed_environment():


def get_model_class_from_config(model_path):
"""Get the actual model class (not just the name) from a pretrained path."""
"""Get the actual model class (not just the name) from a pretrained path.

Note: vlm_utils.is_vlm_with_causal_lm() handles the broader VLM detection
(deciding *whether* to extract). This function resolves the model CLASS
regardless of VLM status, falling back to text_config when needed.
"""
# get the model class from config
# TODO: make the `trust_remote_code` setting configurable somehow
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
mapping = MODEL_FOR_CAUSAL_LM_MAPPING

config_class = config.__class__
if config_class not in mapping:
raise ValueError(f"Model class {config_class} not found in mapping {mapping}")
return mapping[config_class]
if config_class in mapping:
return mapping[config_class]

# Fallback: for VLM models that wrap a CausalLM text backbone
# (e.g. Mistral3Config wrapping Ministral3Config), check text_config
text_config = getattr(config, "text_config", None)
if text_config is not None and text_config.__class__ in mapping:
return mapping[text_config.__class__]

raise ValueError(f"Model class {config_class} not found in mapping {mapping}")


def destroy_distributed_environment():
Expand Down
147 changes: 147 additions & 0 deletions src/mini_trainer/vlm_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
"""Utilities for detecting and extracting CausalLM text backbones from VLM models.

Vision-Language Models (VLMs) like Mistral3ForConditionalGeneration wrap a
CausalLM text backbone (e.g. Ministral3ForCausalLM). This module provides
helpers to detect that wrapping and extract the text backbone so mini-trainer
can treat it as a standard CausalLM for SFT / OSFT training.
"""

import torch
import torch.nn as nn
from transformers.models.auto import MODEL_FOR_CAUSAL_LM_MAPPING

from mini_trainer.utils import log_rank_0


def is_vlm_with_causal_lm(config) -> bool:
    """Determine whether *config* describes a VLM wrapping a CausalLM backbone.

    A model counts as a "VLM wrapping a CausalLM" only when its top-level
    config class is absent from the CausalLM auto-mapping while the nested
    ``text_config`` class is present in it. Configs that are themselves
    registered as CausalLM (even if they also carry a ``text_config``)
    do not qualify.

    Args:
        config: An already-loaded HuggingFace model config object.

    Returns:
        True when the top-level config is not a registered CausalLM but its
        ``text_config`` is; False otherwise.
    """
    if config.__class__ in MODEL_FOR_CAUSAL_LM_MAPPING:
        # Directly trainable as a CausalLM — no extraction needed.
        return False
    inner = getattr(config, "text_config", None)
    if inner is None:
        return False
    return inner.__class__ in MODEL_FOR_CAUSAL_LM_MAPPING


def _find_text_backbone(vlm_model: nn.Module) -> nn.Module:
"""Auto-detect the text backbone inside a VLM model.

Tries well-known attribute names first (``language_model``,
``text_model``, ``llm``), then falls back to searching
``named_children`` for class names containing ``ForCausalLM`` or
``TextModel``.

Args:
vlm_model: The loaded VLM model.

Returns:
The text backbone module.

Raises:
ValueError: If no text backbone can be found.
"""
inner = vlm_model.model if hasattr(vlm_model, "model") else vlm_model

# Well-known attribute names
for attr_name in ("language_model", "text_model", "llm"):
if hasattr(inner, attr_name):
return getattr(inner, attr_name)

# Fallback: search named children for common class-name patterns
for name, child in inner.named_children():
cls_name = child.__class__.__name__
if "ForCausalLM" in cls_name or "TextModel" in cls_name:
return child

available = [name for name, _ in inner.named_children()]
raise ValueError(
f"Cannot find text backbone in {type(vlm_model).__name__}. "
f"Available sub-modules on inner model: {available}"
)


def extract_causal_lm_from_vlm(model_path: str, load_kwargs: dict) -> nn.Module:
    """Load a VLM checkpoint and carve out its CausalLM text backbone.

    The full VLM is materialized via ``AutoModelForImageTextToText``, the
    text backbone is located with :func:`_find_text_backbone`, and a fresh
    standalone CausalLM is assembled from the backbone plus the VLM's
    ``lm_head``.

    Args:
        model_path: HuggingFace model name or local path.
        load_kwargs: Keyword arguments forwarded to ``from_pretrained``.

    Returns:
        A standalone CausalLM model with the VLM's text weights.

    Raises:
        ValueError: If the VLM exposes no ``lm_head`` to transfer.
    """
    from transformers import AutoConfig, AutoModelForImageTextToText

    log_rank_0("🔄 VLM detected – loading full VLM to extract CausalLM text backbone")

    # Drop an explicit quantization_config=None so the model's built-in
    # quantization handling (e.g. FP8 auto-dequant) is not disturbed.
    vlm_kwargs = dict(load_kwargs)
    if "quantization_config" in vlm_kwargs and vlm_kwargs["quantization_config"] is None:
        del vlm_kwargs["quantization_config"]
    vlm = AutoModelForImageTextToText.from_pretrained(model_path, **vlm_kwargs)

    # Auto-detect the text backbone inside the wrapper.
    backbone = _find_text_backbone(vlm)

    # Resolve the nested text_config and build an empty standalone CausalLM.
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    text_config = config.text_config
    causal_lm_class = MODEL_FOR_CAUSAL_LM_MAPPING[text_config.__class__]

    log_rank_0(f"  Extracting {causal_lm_class.__name__} from {type(vlm).__name__}")
    text_model = causal_lm_class(text_config)

    # Swap the backbone weights into the standalone model.
    text_model.model = backbone

    # Transfer the language-model head; without it the extraction is unusable.
    if not hasattr(vlm, "lm_head"):
        raise ValueError(f"Cannot extract lm_head from {type(vlm).__name__}")
    text_model.lm_head = vlm.lm_head

    # Release the full VLM before returning the slimmed-down text model.
    del vlm
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    log_rank_0(f"  ✅ Extracted {causal_lm_class.__name__} successfully")
    return text_model


def has_mrope(config) -> bool:
    """Report whether *config* declares M-RoPE (multimodal rotary position embeddings).

    Both the top-level config and its ``text_config`` (when present) are
    inspected; M-RoPE is signalled by an ``mrope_section`` key inside a
    ``rope_scaling`` or ``rope_parameters`` dict.

    Args:
        config: An already-loaded HuggingFace model config object.

    Returns:
        True if M-RoPE is detected.
    """
    candidates = [config]
    nested = getattr(config, "text_config", None)
    if nested is not None:
        candidates.append(nested)

    def _declares_mrope(cfg) -> bool:
        # An mrope_section entry in either rope dict marks an M-RoPE model.
        for field in ("rope_scaling", "rope_parameters"):
            params = getattr(cfg, field, None)
            if isinstance(params, dict) and "mrope_section" in params:
                return True
        return False

    return any(_declares_mrope(cfg) for cfg in candidates)
Loading