Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 47 additions & 16 deletions src/mini_trainer/osft_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,14 @@
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.checkpoint.state_dict import (
StateDictOptions,
set_model_state_dict,
)
from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict
from tqdm import tqdm
from transformers.models.gpt_oss.modeling_gpt_oss import GptOssForCausalLM

from mini_trainer.fsdp2_lazy_init import (
FSDP2_LAZY_INIT_OSFT,
get_fsdp2_lazy_init_mode,
set_fsdp2_lazy_init_mode,
)
from mini_trainer.fsdp2_lazy_init import FSDP2_LAZY_INIT_OSFT, get_fsdp2_lazy_init_mode, set_fsdp2_lazy_init_mode
from mini_trainer.gpt_oss_utils import is_gpt_oss_model
from mini_trainer.utils import get_control_process_group, log_rank_0
from mini_trainer.vlm_utils import extract_causal_lm_from_vlm, is_vlm_with_causal_lm

# Memory optimization constants
OSFT_CACHE_CLEAR_INTERVAL = int(
Expand Down Expand Up @@ -315,6 +309,18 @@ def _reconstruct_weight(
# "experts.down_proj",
]
},
"qwen3_5": {
"patterns": [
"self_attn.q_proj",
"self_attn.k_proj",
"self_attn.v_proj",
"self_attn.o_proj",
"linear_attn.out_proj",
"mlp.gate_proj",
"mlp.down_proj",
"mlp.up_proj",
]
},
"default": {
"patterns": [
"self_attn.q_proj",
Expand All @@ -337,6 +343,8 @@ def _reconstruct_weight(
"gptneo": "gpt-neo", # Handle both "gpt-neo" and "gptneo" variants
"gpt-oss": "gpt-oss",
"opt": "opt",
"qwen3_5": "qwen3_5", # specific BEFORE generic
"qwen3.5": "qwen3_5", # specific BEFORE generic
"qwen": "qwen",
"gemma": "gemma",
"phi4": "phi3",
Expand All @@ -346,8 +354,6 @@ def _reconstruct_weight(
"mistral": "mistral",
"granite": "granite",
"gpt2": "gpt2",
# Easy to add more mappings
# "phi": "phi",
}


Expand Down Expand Up @@ -767,11 +773,30 @@ def _load_model_memory_efficient(
if dist.get_rank() == 0:
with torch.no_grad():
log_rank_0(f"📥 Loading base model to CPU in {load_dtype}...")
base_model = base_model_class.from_pretrained(
pretrained_model_name_or_path,
*model_args,
**final_base_kwargs,
)

# Check if this is a VLM wrapping a CausalLM text backbone
_is_vlm = False
try:
from transformers import AutoConfig

_pre_config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)
_is_vlm = is_vlm_with_causal_lm(_pre_config)
except (OSError, ValueError):
# Config loading can fail for local-only or mock paths
pass

if _is_vlm:
log_rank_0("🔄 VLM detected – extracting CausalLM text backbone for OSFT")
# Filter out pretrained_model_name_or_path to avoid duplicate
# argument since it's passed positionally to extract_causal_lm_from_vlm
vlm_kwargs = {k: v for k, v in final_base_kwargs.items() if k != "pretrained_model_name_or_path"}
base_model = extract_causal_lm_from_vlm(pretrained_model_name_or_path, vlm_kwargs)
else:
base_model = base_model_class.from_pretrained(
pretrained_model_name_or_path,
*model_args,
**final_base_kwargs,
)

align_fn = osft_class_kwargs.get("lazy_init_tokenizer_align_fn")
if align_fn:
Expand Down Expand Up @@ -888,6 +913,12 @@ def create_osft_model_class(base_cls) -> type[OSFTModel]:
class ModelWithOSFT(base_cls):
osft_config: dict[str, int]

@classmethod
def _can_set_experts_implementation(cls):
    """Forward the experts-implementation capability query to ``base_cls``.

    Returns:
        Whatever ``base_cls._can_set_experts_implementation()`` returns when
        the wrapped base class defines that hook, otherwise ``False``.
    """
    # Guard clause: base classes that do not define the hook are reported
    # as not supporting it.
    if not hasattr(base_cls, "_can_set_experts_implementation"):
        return False
    # Delegate to the base class explicitly (not via ``cls``/``super()``), so the
    # answer comes from the class being wrapped.
    return base_cls._can_set_experts_implementation()

def __init__(
self,
config,
Expand Down
5 changes: 1 addition & 4 deletions src/mini_trainer/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,7 @@
import numpy as np
import torch
import torch.distributed as dist
from datasets import (
Dataset as HFDataset,
load_dataset,
)
from datasets import Dataset as HFDataset, load_dataset
from deprecated import deprecated
from torch.utils.data import DataLoader, Dataset, Sampler, SequentialSampler

Expand Down
174 changes: 133 additions & 41 deletions src/mini_trainer/setup_model_for_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,8 @@

import torch
import torch.distributed as dist
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper as ptd_checkpoint_wrapper,
)
from torch.distributed.checkpoint.state_dict import (
StateDictOptions,
set_model_state_dict,
)
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper as ptd_checkpoint_wrapper
from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard
from transformers import AutoConfig, AutoTokenizer, Mxfp4Config
Expand All @@ -26,16 +21,15 @@
set_fsdp2_lazy_init_mode,
)
from mini_trainer.gpt_oss_utils import freeze_router_params, is_gpt_oss_model
from mini_trainer.osft_utils import (
OSFTModel,
_build_osft_kwargs,
_set_osft_dtypes,
create_osft_model_class,
)
from mini_trainer.utils import (
get_model_class_from_config,
log_rank_0,
patch_target_module,
from mini_trainer.osft_utils import OSFTModel, _build_osft_kwargs, _set_osft_dtypes, create_osft_model_class
from mini_trainer.utils import get_model_class_from_config, log_rank_0, patch_target_module
from mini_trainer.vlm_utils import (
extract_causal_lm_from_vlm,
has_timm_vision_tower,
is_vlm_for_direct_loading,
is_vlm_with_causal_lm,
load_vlm_for_text_training,
needs_sdpa,
)


Expand Down Expand Up @@ -66,10 +60,7 @@ def _apply_liger_kernels_if_requested(use_liger_kernels, model_config, base_mode
return

try:
from liger_kernel.transformers.monkey_patch import (
MODEL_TYPE_TO_APPLY_LIGER_FN,
_apply_liger_kernel,
)
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN, _apply_liger_kernel
except ImportError as e:
raise ImportError(
"Tried to use liger kernels for OSFT, but they are not installed. "
Expand Down Expand Up @@ -438,8 +429,17 @@ def wrap_fsdp2(model: torch.nn.Module) -> torch.nn.Module:
print(f"WARNING: Failed to disable HuggingFace cache for model {model.__class__.__name__}: {e}")

# Find the transformer block container
# Support common architectures: Llama (model.layers), GPT-2 (transformer.h)
if hasattr(model, "model") and hasattr(model.model, "layers"):
# Support common architectures:
# - VLM direct load (model.model.language_model.layers) e.g. Qwen3-VL
# - Llama-style (model.model.layers)
# - GPT-2-style (transformer.h)
if (
hasattr(model, "model")
and hasattr(model.model, "language_model")
and hasattr(model.model.language_model, "layers")
):
layers = model.model.language_model.layers
elif hasattr(model, "model") and hasattr(model.model, "layers"):
layers = model.model.layers
elif hasattr(model, "transformer") and hasattr(model.transformer, "h"):
layers = model.transformer.h
Expand Down Expand Up @@ -817,12 +817,24 @@ def setup_sft_model_distributed(
state_dict = None
buffer_dict = None

# Check if this is a VLM wrapping a CausalLM text backbone, or a direct VLM
_model_config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
is_vlm = is_vlm_with_causal_lm(_model_config)
is_direct_vlm = is_vlm_for_direct_loading(_model_config)

if dist.get_rank() == 0:
log_rank_0("rank 0: loading model to CPU")
try:
with torch.no_grad():
# Default load targets CPU when no device_map or accelerate is present
cpu_model = ModelClass.from_pretrained(**base_model_args)
if is_vlm:
# VLM model: load full VLM and extract CausalLM text backbone
cpu_model = extract_causal_lm_from_vlm(model_name_or_path, base_model_args)
elif is_direct_vlm:
# VLM with no CausalLM class: load directly for text-only training
cpu_model = load_vlm_for_text_training(model_name_or_path, base_model_args)
else:
# Standard CausalLM: load directly
cpu_model = ModelClass.from_pretrained(**base_model_args)
cpu_model = align_model_and_tokenizer(cpu_model, tokenizer)
config = cpu_model.config
state_dict = cpu_model.state_dict()
Expand All @@ -845,8 +857,14 @@ def setup_sft_model_distributed(

# All ranks: Create model on meta device
log_rank_0("creating model on meta device")
with torch.device("meta"):
model = ModelClass.from_config(config)
if is_direct_vlm:
from transformers import AutoModelForImageTextToText

with torch.device("meta"):
model = AutoModelForImageTextToText.from_config(config)
else:
with torch.device("meta"):
model = ModelClass.from_config(config)

# Align model with tokenizer
model = align_model_and_tokenizer(model, tokenizer)
Expand Down Expand Up @@ -882,6 +900,28 @@ def setup_model(
model_config = AutoConfig.from_pretrained(model_name_or_path)
is_gpt_oss = is_gpt_oss_model(model_config)

# Pre-populate the transformers Hub kernel cache with locally installed
# mamba_ssm and causal_conv1d packages. The Hub kernel versions may be
# compiled against a different PyTorch/CUDA build, causing runtime errors.
# Using local packages ensures ABI compatibility.
if getattr(model_config, "model_type", None) == "granitemoehybrid":
try:
import causal_conv1d
import mamba_ssm
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
from transformers.integrations.hub_kernels import _KERNEL_MODULE_MAPPING

mamba_ssm.selective_state_update = selective_state_update
mamba_ssm.mamba_chunk_scan_combined = mamba_chunk_scan_combined
mamba_ssm.mamba_split_conv1d_scan_combined = mamba_split_conv1d_scan_combined

_KERNEL_MODULE_MAPPING["causal-conv1d"] = causal_conv1d
_KERNEL_MODULE_MAPPING["mamba-ssm"] = mamba_ssm
log_rank_0("Using local mamba_ssm/causal_conv1d instead of Hub kernels")
except (ImportError, AttributeError) as e:
log_rank_0(f"Could not patch mamba kernels ({e}); GraniteMoeHybrid may use Hub kernels")

# Set up quantization config for GPT-OSS models
if is_gpt_oss:
try:
Expand All @@ -896,21 +936,52 @@ def setup_model(
except ImportError:
log_rank_0("⚠️ GPT-OSS model detected but Mxfp4Config not available - using default config")

# Check if flash_attn is available and set appropriate attention implementation
try:
import flash_attn as _ # noqa: F401
# Check if model requires SDPA instead of Flash Attention 2.
# This covers M-RoPE models (3D position_ids) and models with timm vision
# towers (TimmWrapperModel rejects flash_attention_2).
_needs_sdpa = needs_sdpa(model_config)

if is_gpt_oss:
base_model_args["attn_implementation"] = "kernels-community/vllm-flash-attn3"
log_rank_0("Set attention implementation to vllm-flash-attn3 for GPT-OSS")
else:
base_model_args["attn_implementation"] = "flash_attention_2"
# Handle models that need SDPA (doesn't require flash_attn)
if _needs_sdpa:
base_model_args["attn_implementation"] = "sdpa"
log_rank_0(f"Using SDPA for {model_name_or_path} (model incompatible with Flash Attention 2)")
else:
# Check if flash_attn is available for non-SDPA models
try:
import flash_attn as _

if is_gpt_oss:
# vllm-flash-attn3 requires Hopper (SM 9.0+) GPUs;
# GPT-OSS only supports flash-attn3 or eager
major, _ = torch.cuda.get_device_capability(0)
if major >= 9:
base_model_args["attn_implementation"] = "kernels-community/vllm-flash-attn3"
log_rank_0("Set attention implementation to vllm-flash-attn3 for GPT-OSS")
else:
base_model_args["attn_implementation"] = "eager"
log_rank_0(
f"GPT-OSS: flash-attn3 requires Hopper (SM 9.0+) GPUs, "
f"but found SM {major}.x. Using eager attention instead."
)
else:
base_model_args["attn_implementation"] = "flash_attention_2"

except ImportError as e:
if os.environ.get("TESTING", "false").lower() == "true":
base_model_args["attn_implementation"] = "sdpa"
else:
raise e
except ImportError as e:
if os.environ.get("TESTING", "false").lower() == "true":
base_model_args["attn_implementation"] = "sdpa"
else:
raise e

# For models with timm vision towers: set vision config to eager
# while keeping the text model's attention implementation.
# timm's TimmWrapperModel rejects both FA2 and SDPA.
if has_timm_vision_tower(model_config):
attn_impl = base_model_args.get("attn_implementation", "flash_attention_2")
base_model_args["attn_implementation"] = {
"text_config": attn_impl,
"vision_config": "eager",
}
log_rank_0(f"Model has timm vision tower — using eager attention for vision, {attn_impl} for text model.")

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

Expand Down Expand Up @@ -961,11 +1032,28 @@ def load_standard_model():
)
else:
# non-distributed path: direct loading
model = ModelClass.from_pretrained(**base_model_args)
if is_vlm_with_causal_lm(model_config):
model = extract_causal_lm_from_vlm(model_name_or_path, base_model_args)
elif is_vlm_for_direct_loading(model_config):
model = load_vlm_for_text_training(model_name_or_path, base_model_args)
else:
model = ModelClass.from_pretrained(**base_model_args)
return align_model_and_tokenizer(model, tokenizer)

def load_osft_model():
"""Load a model with OSFT (Orthogonal Subspace Fine-Tuning) support."""
# Direct VLMs (no CausalLM class) are not supported for OSFT yet.
# OSFT wraps the base model class and modifies internal weights, which
# requires a CausalLM-compatible architecture. Direct VLMs have a
# different layer structure (model.model.language_model.layers) that
# would need significant OSFT adapter changes.
if is_vlm_for_direct_loading(model_config):
raise ValueError(
f"OSFT is not supported for direct VLM models (e.g. {model_name_or_path}). "
"This model has no standalone CausalLM class and cannot be wrapped by OSFT. "
"Use SFT training instead (osft=False)."
)

log_rank_0("loading OSFT model")
# If osft_output_dtype is not specified, use train_dtype for consistency
effective_osft_output_dtype = osft_output_dtype if osft_output_dtype is not None else train_dtype
Expand Down Expand Up @@ -1070,6 +1158,7 @@ def load_osft_model():
# List of supported architectures
if class_name not in [
"MistralForCausalLM",
"Ministral3ForCausalLM",
"GPTDolomiteForCausalLM",
"LlamaForCausalLM",
"Starcoder2ForCausalLM",
Expand All @@ -1080,6 +1169,9 @@ def load_osft_model():
"Qwen2ForCausalLM",
"Phi3ForCausalLM", # covers phi3 and phi4
"Qwen3ForCausalLM",
"Qwen3_5ForCausalLM",
"Qwen3VLForConditionalGeneration", # direct VLM loading (no CausalLM class)
"Gemma3nForConditionalGeneration", # dual-registered VLM loaded as CausalLM
]:
log_rank_0(
f"\033[38;2;255;255;0mWarning: Model class name: {class_name} is not in the list of supported models.\033[0m",
Expand Down
Loading
Loading