Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 47 additions & 16 deletions src/mini_trainer/osft_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,14 @@
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.checkpoint.state_dict import (
StateDictOptions,
set_model_state_dict,
)
from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict
from tqdm import tqdm
from transformers.models.gpt_oss.modeling_gpt_oss import GptOssForCausalLM

from mini_trainer.fsdp2_lazy_init import (
FSDP2_LAZY_INIT_OSFT,
get_fsdp2_lazy_init_mode,
set_fsdp2_lazy_init_mode,
)
from mini_trainer.fsdp2_lazy_init import FSDP2_LAZY_INIT_OSFT, get_fsdp2_lazy_init_mode, set_fsdp2_lazy_init_mode
from mini_trainer.gpt_oss_utils import is_gpt_oss_model
from mini_trainer.utils import get_control_process_group, log_rank_0
from mini_trainer.vlm_utils import extract_causal_lm_from_vlm, is_vlm_with_causal_lm

# Memory optimization constants
OSFT_CACHE_CLEAR_INTERVAL = int(
Expand Down Expand Up @@ -315,6 +309,18 @@ def _reconstruct_weight(
# "experts.down_proj",
]
},
"qwen3_5": {
"patterns": [
"self_attn.q_proj",
"self_attn.k_proj",
"self_attn.v_proj",
"self_attn.o_proj",
"linear_attn.out_proj",
"mlp.gate_proj",
"mlp.down_proj",
"mlp.up_proj",
]
},
"default": {
"patterns": [
"self_attn.q_proj",
Expand All @@ -337,6 +343,8 @@ def _reconstruct_weight(
"gptneo": "gpt-neo", # Handle both "gpt-neo" and "gptneo" variants
"gpt-oss": "gpt-oss",
"opt": "opt",
"qwen3_5": "qwen3_5", # specific BEFORE generic
"qwen3.5": "qwen3_5", # specific BEFORE generic
"qwen": "qwen",
"gemma": "gemma",
"phi4": "phi3",
Expand All @@ -346,8 +354,6 @@ def _reconstruct_weight(
"mistral": "mistral",
"granite": "granite",
"gpt2": "gpt2",
# Easy to add more mappings
# "phi": "phi",
}


Expand Down Expand Up @@ -767,11 +773,30 @@ def _load_model_memory_efficient(
if dist.get_rank() == 0:
with torch.no_grad():
log_rank_0(f"📥 Loading base model to CPU in {load_dtype}...")
base_model = base_model_class.from_pretrained(
pretrained_model_name_or_path,
*model_args,
**final_base_kwargs,
)

# Check if this is a VLM wrapping a CausalLM text backbone
_is_vlm = False
try:
from transformers import AutoConfig

_pre_config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)
_is_vlm = is_vlm_with_causal_lm(_pre_config)
except (OSError, ValueError):
# Config loading can fail for local-only or mock paths
pass

if _is_vlm:
log_rank_0("🔄 VLM detected – extracting CausalLM text backbone for OSFT")
# Filter out pretrained_model_name_or_path to avoid duplicate
# argument since it's passed positionally to extract_causal_lm_from_vlm
vlm_kwargs = {k: v for k, v in final_base_kwargs.items() if k != "pretrained_model_name_or_path"}
base_model = extract_causal_lm_from_vlm(pretrained_model_name_or_path, vlm_kwargs)
else:
base_model = base_model_class.from_pretrained(
pretrained_model_name_or_path,
*model_args,
**final_base_kwargs,
)

align_fn = osft_class_kwargs.get("lazy_init_tokenizer_align_fn")
if align_fn:
Expand Down Expand Up @@ -888,6 +913,12 @@ def create_osft_model_class(base_cls) -> type[OSFTModel]:
class ModelWithOSFT(base_cls):
osft_config: dict[str, int]

@classmethod
def _can_set_experts_implementation(cls):
if hasattr(base_cls, "_can_set_experts_implementation"):
return base_cls._can_set_experts_implementation()
return False

def __init__(
self,
config,
Expand Down
5 changes: 1 addition & 4 deletions src/mini_trainer/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,7 @@
import numpy as np
import torch
import torch.distributed as dist
from datasets import (
Dataset as HFDataset,
load_dataset,
)
from datasets import Dataset as HFDataset, load_dataset
from deprecated import deprecated
from torch.utils.data import DataLoader, Dataset, Sampler, SequentialSampler

Expand Down
Loading
Loading