Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 37 additions & 6 deletions src/mini_trainer/osft_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
)
from mini_trainer.gpt_oss_utils import is_gpt_oss_model
from mini_trainer.utils import get_control_process_group, log_rank_0
from mini_trainer.vlm_utils import extract_causal_lm_from_vlm, is_vlm_with_causal_lm

# Memory optimization constants
OSFT_CACHE_CLEAR_INTERVAL = int(
Expand Down Expand Up @@ -315,6 +316,18 @@ def _reconstruct_weight(
# "experts.down_proj",
]
},
"qwen3_5": {
"patterns": [
"self_attn.q_proj",
"self_attn.k_proj",
"self_attn.v_proj",
"self_attn.o_proj",
"linear_attn.out_proj",
"mlp.gate_proj",
"mlp.down_proj",
"mlp.up_proj",
]
},
"default": {
"patterns": [
"self_attn.q_proj",
Expand Down Expand Up @@ -346,8 +359,8 @@ def _reconstruct_weight(
"mistral": "mistral",
"granite": "granite",
"gpt2": "gpt2",
# Easy to add more mappings
# "phi": "phi",
"qwen3_5": "qwen3_5",
"qwen3.5": "qwen3_5",
}


Expand Down Expand Up @@ -767,12 +780,26 @@ def _load_model_memory_efficient(
if dist.get_rank() == 0:
with torch.no_grad():
log_rank_0(f"📥 Loading base model to CPU in {load_dtype}...")
base_model = base_model_class.from_pretrained(
pretrained_model_name_or_path,
*model_args,
**final_base_kwargs,

# Check if this is a VLM wrapping a CausalLM text backbone
from transformers import AutoConfig

_pre_config = AutoConfig.from_pretrained(
pretrained_model_name_or_path, trust_remote_code=True
)

if is_vlm_with_causal_lm(_pre_config):
log_rank_0("🔄 VLM detected – extracting CausalLM text backbone for OSFT")
base_model = extract_causal_lm_from_vlm(
pretrained_model_name_or_path, final_base_kwargs
)
else:
base_model = base_model_class.from_pretrained(
pretrained_model_name_or_path,
*model_args,
**final_base_kwargs,
)

align_fn = osft_class_kwargs.get("lazy_init_tokenizer_align_fn")
if align_fn:
base_model = align_fn(base_model)
Expand Down Expand Up @@ -888,6 +915,10 @@ def create_osft_model_class(base_cls) -> type[OSFTModel]:
class ModelWithOSFT(base_cls):
osft_config: dict[str, int]

@classmethod
def _can_set_experts_implementation(cls):
return base_cls._can_set_experts_implementation()

def __init__(
self,
config,
Expand Down
132 changes: 122 additions & 10 deletions src/mini_trainer/setup_model_for_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,14 @@
log_rank_0,
patch_target_module,
)
from mini_trainer.vlm_utils import (
extract_causal_lm_from_vlm,
needs_sdpa,
has_timm_vision_tower,
is_vlm_for_direct_loading,
is_vlm_with_causal_lm,
load_vlm_for_text_training,
)


def _distributed_initialized() -> bool:
Expand Down Expand Up @@ -438,8 +446,13 @@ def wrap_fsdp2(model: torch.nn.Module) -> torch.nn.Module:
print(f"WARNING: Failed to disable HuggingFace cache for model {model.__class__.__name__}: {e}")

# Find the transformer block container
# Support common architectures: Llama (model.layers), GPT-2 (transformer.h)
if hasattr(model, "model") and hasattr(model.model, "layers"):
# Support common architectures:
# - VLM direct load (model.model.language_model.layers) e.g. Qwen3-VL
# - Llama-style (model.model.layers)
# - GPT-2-style (transformer.h)
if hasattr(model, "model") and hasattr(model.model, "language_model") and hasattr(model.model.language_model, "layers"):
layers = model.model.language_model.layers
elif hasattr(model, "model") and hasattr(model.model, "layers"):
layers = model.model.layers
elif hasattr(model, "transformer") and hasattr(model.transformer, "h"):
layers = model.transformer.h
Expand Down Expand Up @@ -817,12 +830,24 @@ def setup_sft_model_distributed(
state_dict = None
buffer_dict = None

# Check if this is a VLM wrapping a CausalLM text backbone, or a direct VLM
_model_config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
is_vlm = is_vlm_with_causal_lm(_model_config)
is_direct_vlm = is_vlm_for_direct_loading(_model_config)

if dist.get_rank() == 0:
log_rank_0("rank 0: loading model to CPU")
try:
with torch.no_grad():
# Default load targets CPU when no device_map or accelerate is present
cpu_model = ModelClass.from_pretrained(**base_model_args)
if is_vlm:
# VLM model: load full VLM and extract CausalLM text backbone
cpu_model = extract_causal_lm_from_vlm(model_name_or_path, base_model_args)
elif is_direct_vlm:
# VLM with no CausalLM class: load directly for text-only training
cpu_model = load_vlm_for_text_training(model_name_or_path, base_model_args)
else:
# Standard CausalLM: load directly
cpu_model = ModelClass.from_pretrained(**base_model_args)
cpu_model = align_model_and_tokenizer(cpu_model, tokenizer)
config = cpu_model.config
state_dict = cpu_model.state_dict()
Expand All @@ -845,8 +870,14 @@ def setup_sft_model_distributed(

# All ranks: Create model on meta device
log_rank_0("creating model on meta device")
with torch.device("meta"):
model = ModelClass.from_config(config)
if is_direct_vlm:
from transformers import AutoModelForImageTextToText

with torch.device("meta"):
model = AutoModelForImageTextToText.from_config(config)
else:
with torch.device("meta"):
model = ModelClass.from_config(config)

# Align model with tokenizer
model = align_model_and_tokenizer(model, tokenizer)
Expand Down Expand Up @@ -882,6 +913,32 @@ def setup_model(
model_config = AutoConfig.from_pretrained(model_name_or_path)
is_gpt_oss = is_gpt_oss_model(model_config)

# The Hub kernel for mamba-ssm is incompatible with causal_conv1d v1.6.0
# (different C API). Use local packages instead.
if getattr(model_config, "model_type", None) == "granitemoehybrid":
try:
from transformers.integrations.hub_kernels import _KERNEL_MODULE_MAPPING
import causal_conv1d
import mamba_ssm
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
from mamba_ssm.ops.triton.ssd_combined import (
mamba_chunk_scan_combined,
mamba_split_conv1d_scan_combined,
)

mamba_ssm.selective_state_update = selective_state_update
mamba_ssm.mamba_chunk_scan_combined = mamba_chunk_scan_combined
mamba_ssm.mamba_split_conv1d_scan_combined = mamba_split_conv1d_scan_combined

_KERNEL_MODULE_MAPPING["causal-conv1d"] = causal_conv1d
_KERNEL_MODULE_MAPPING["mamba-ssm"] = mamba_ssm
log_rank_0("Using local mamba_ssm/causal_conv1d instead of Hub kernels")
except ImportError:
log_rank_0(
"mamba_ssm or causal_conv1d not installed; "
"GraniteMoeHybrid will use slow (torch) path"
)

# Set up quantization config for GPT-OSS models
if is_gpt_oss:
try:
Expand All @@ -896,13 +953,33 @@ def setup_model(
except ImportError:
log_rank_0("⚠️ GPT-OSS model detected but Mxfp4Config not available - using default config")

# Check if model requires SDPA instead of Flash Attention 2.
# This covers M-RoPE models (3D position_ids) and models with timm vision
# towers (TimmWrapperModel rejects flash_attention_2).
_needs_sdpa = needs_sdpa(model_config)

# Check if flash_attn is available and set appropriate attention implementation
try:
import flash_attn as _ # noqa: F401

if is_gpt_oss:
base_model_args["attn_implementation"] = "kernels-community/vllm-flash-attn3"
log_rank_0("Set attention implementation to vllm-flash-attn3 for GPT-OSS")
if _needs_sdpa:
base_model_args["attn_implementation"] = "sdpa"
log_rank_0(
f"Using SDPA for {model_name_or_path} (model incompatible with Flash Attention 2)"
)
elif is_gpt_oss:
# vllm-flash-attn3 requires Hopper (SM 9.0+) GPUs;
# GPT-OSS only supports flash-attn3 or eager
major, _ = torch.cuda.get_device_capability(0)
if major >= 9:
base_model_args["attn_implementation"] = "kernels-community/vllm-flash-attn3"
log_rank_0("Set attention implementation to vllm-flash-attn3 for GPT-OSS")
else:
base_model_args["attn_implementation"] = "eager"
log_rank_0(
f"GPT-OSS: flash-attn3 requires Hopper (SM 9.0+) GPUs, "
f"but found SM {major}.x. Using eager attention instead."
)
else:
base_model_args["attn_implementation"] = "flash_attention_2"

Expand All @@ -912,6 +989,20 @@ def setup_model(
else:
raise e

# For models with timm vision towers: set vision config to eager
# while keeping the text model's attention implementation.
# timm's TimmWrapperModel rejects both FA2 and SDPA.
if has_timm_vision_tower(model_config):
attn_impl = base_model_args.get("attn_implementation", "flash_attention_2")
base_model_args["attn_implementation"] = {
"text_config": attn_impl,
"vision_config": "eager",
}
log_rank_0(
f"Model has timm vision tower — using eager attention for vision, "
f"{attn_impl} for text model."
)

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

# patch both loss functions, since models will use the regular HF
Expand Down Expand Up @@ -961,11 +1052,28 @@ def load_standard_model():
)
else:
# non-distributed path: direct loading
model = ModelClass.from_pretrained(**base_model_args)
if is_vlm_with_causal_lm(model_config):
model = extract_causal_lm_from_vlm(model_name_or_path, base_model_args)
elif is_vlm_for_direct_loading(model_config):
model = load_vlm_for_text_training(model_name_or_path, base_model_args)
else:
model = ModelClass.from_pretrained(**base_model_args)
return align_model_and_tokenizer(model, tokenizer)

def load_osft_model():
"""Load a model with OSFT (Orthogonal Subspace Fine-Tuning) support."""
# Direct VLMs (no CausalLM class) are not supported for OSFT yet.
# OSFT wraps the base model class and modifies internal weights, which
# requires a CausalLM-compatible architecture. Direct VLMs have a
# different layer structure (model.model.language_model.layers) that
# would need significant OSFT adapter changes.
if is_vlm_for_direct_loading(model_config):
raise ValueError(
f"OSFT is not supported for direct VLM models (e.g. {model_name_or_path}). "
"This model has no standalone CausalLM class and cannot be wrapped by OSFT. "
"Use SFT training instead (osft=False)."
)

log_rank_0("loading OSFT model")
# If osft_output_dtype is not specified, use train_dtype for consistency
effective_osft_output_dtype = osft_output_dtype if osft_output_dtype is not None else train_dtype
Expand Down Expand Up @@ -1070,6 +1178,7 @@ def load_osft_model():
# List of supported architectures
if class_name not in [
"MistralForCausalLM",
"Ministral3ForCausalLM",
"GPTDolomiteForCausalLM",
"LlamaForCausalLM",
"Starcoder2ForCausalLM",
Expand All @@ -1080,6 +1189,9 @@ def load_osft_model():
"Qwen2ForCausalLM",
"Phi3ForCausalLM", # covers phi3 and phi4
"Qwen3ForCausalLM",
"Qwen3_5ForCausalLM",
"Qwen3VLForConditionalGeneration", # direct VLM loading (no CausalLM class)
"Gemma3nForConditionalGeneration", # dual-registered VLM loaded as CausalLM
]:
log_rank_0(
f"\033[38;2;255;255;0mWarning: Model class name: {class_name} is not in the list of supported models.\033[0m",
Expand Down
27 changes: 23 additions & 4 deletions src/mini_trainer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,16 +139,35 @@ def init_distributed_environment():


def get_model_class_from_config(model_path):
    """Get the actual model class (not just the name) from a pretrained path.

    Resolution order:

    1. The config class itself, looked up in ``MODEL_FOR_CAUSAL_LM_MAPPING``.
    2. The class of ``config.text_config`` (when present), looked up in the
       same mapping. This covers VLM configs that wrap a CausalLM text
       backbone (e.g. Mistral3Config wrapping Ministral3Config).
    3. The config class, looked up in ``MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING``.
       This covers VLMs with no CausalLM class at all (e.g. Qwen3-VL-2B) so
       they can be loaded directly.

    Note: vlm_utils.is_vlm_with_causal_lm() handles the broader VLM detection
    (deciding *whether* to extract). This function resolves the model CLASS
    regardless of VLM status, falling back to text_config when needed.

    Args:
        model_path: Value forwarded to ``AutoConfig.from_pretrained``.

    Returns:
        The model class registered for the resolved config class.

    Raises:
        ValueError: If none of the three lookups above finds a match.
    """
    # TODO: make the `trust_remote_code` setting configurable somehow
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    mapping = MODEL_FOR_CAUSAL_LM_MAPPING

    # 1. Direct hit: the config class is registered as a CausalLM.
    config_class = config.__class__
    if config_class in mapping:
        return mapping[config_class]

    # 2. Fallback: VLM config wrapping a CausalLM text backbone.
    text_config = getattr(config, "text_config", None)
    if text_config is not None and text_config.__class__ in mapping:
        return mapping[text_config.__class__]

    # 3. Fallback: VLM with no CausalLM class; resolve it through the
    #    ImageTextToText mapping. Imported locally so the dependency is only
    #    touched when the first two lookups miss.
    from transformers.models.auto import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING

    if config_class in MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING:
        return MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING[config_class]

    # Nothing matched: keep the original error text so callers/tests that
    # match on it continue to work.
    raise ValueError(f"Model class {config_class} not found in mapping {mapping}")


def destroy_distributed_environment():
Expand Down
Loading
Loading