Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 62 additions & 6 deletions src/mini_trainer/osft_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,18 @@ def _reconstruct_weight(
# "experts.down_proj",
]
},
"qwen3_5": {
"patterns": [
"self_attn.q_proj",
"self_attn.k_proj",
"self_attn.v_proj",
"self_attn.o_proj",
"linear_attn.out_proj",
"mlp.gate_proj",
"mlp.down_proj",
"mlp.up_proj",
]
},
"default": {
"patterns": [
"self_attn.q_proj",
Expand Down Expand Up @@ -346,8 +358,8 @@ def _reconstruct_weight(
"mistral": "mistral",
"granite": "granite",
"gpt2": "gpt2",
# Easy to add more mappings
# "phi": "phi",
"qwen3_5": "qwen3_5",
"qwen3.5": "qwen3_5",
}


Expand Down Expand Up @@ -767,12 +779,52 @@ def _load_model_memory_efficient(
if dist.get_rank() == 0:
with torch.no_grad():
log_rank_0(f"📥 Loading base model to CPU in {load_dtype}...")
base_model = base_model_class.from_pretrained(
pretrained_model_name_or_path,
*model_args,
**final_base_kwargs,

# Check if this is a VLM wrapping a CausalLM text backbone
from transformers import AutoConfig
from transformers.models.auto import MODEL_FOR_CAUSAL_LM_MAPPING

_pre_config = AutoConfig.from_pretrained(
pretrained_model_name_or_path, trust_remote_code=True
)
_text_cfg = getattr(_pre_config, "text_config", None)
_is_vlm = (
_text_cfg is not None
and _text_cfg.__class__ in MODEL_FOR_CAUSAL_LM_MAPPING
and _pre_config.__class__ not in MODEL_FOR_CAUSAL_LM_MAPPING
)

if _is_vlm:
# VLM model: load full VLM and extract CausalLM text backbone
from transformers import AutoModelForImageTextToText

log_rank_0("🔄 VLM detected – extracting CausalLM text backbone for OSFT")
# Filter out None quantization_config to avoid interfering with
# the model's built-in quantization handling (e.g. FP8 auto-dequant)
vlm_kwargs = {
k: v for k, v in final_base_kwargs.items()
if not (k == "quantization_config" and v is None)
}
vlm = AutoModelForImageTextToText.from_pretrained(
pretrained_model_name_or_path,
*model_args,
**vlm_kwargs,
)
causal_lm_class = MODEL_FOR_CAUSAL_LM_MAPPING[_text_cfg.__class__]
base_model = causal_lm_class(_text_cfg)
base_model.model = vlm.model.language_model
base_model.lm_head = vlm.lm_head
del vlm
if torch.cuda.is_available():
torch.cuda.empty_cache()
log_rank_0(f" ✅ Extracted {causal_lm_class.__name__}")
else:
base_model = base_model_class.from_pretrained(
pretrained_model_name_or_path,
*model_args,
**final_base_kwargs,
)

align_fn = osft_class_kwargs.get("lazy_init_tokenizer_align_fn")
if align_fn:
base_model = align_fn(base_model)
Expand Down Expand Up @@ -888,6 +940,10 @@ def create_osft_model_class(base_cls) -> type[OSFTModel]:
class ModelWithOSFT(base_cls):
osft_config: dict[str, int]

@classmethod
def _can_set_experts_implementation(cls):
    # Forward the capability query to the wrapped base class so the
    # dynamically created OSFT subclass answers exactly as the model class
    # it wraps would — presumably gating MoE experts-implementation
    # selection in transformers; confirm against base_cls's definition.
    return base_cls._can_set_experts_implementation()

def __init__(
self,
config,
Expand Down
145 changes: 139 additions & 6 deletions src/mini_trainer/setup_model_for_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,83 @@
_set_osft_dtypes,
create_osft_model_class,
)
from transformers.models.auto import MODEL_FOR_CAUSAL_LM_MAPPING

from mini_trainer.utils import (
get_model_class_from_config,
log_rank_0,
patch_target_module,
)


def _is_vlm_wrapping_causal_lm(config) -> bool:
    """Return True when *config* is a VLM config whose text backbone is a CausalLM.

    Some models (e.g. Ministral-3-3B) use a VLM architecture
    (Mistral3ForConditionalGeneration) as a wrapper around a CausalLM text
    model (Ministral3ForCausalLM). In that case the top-level config class is
    absent from MODEL_FOR_CAUSAL_LM_MAPPING while the nested ``text_config``
    class is present.
    """
    # A config that is itself a CausalLM config is not a wrapper.
    if config.__class__ in MODEL_FOR_CAUSAL_LM_MAPPING:
        return False
    inner_cfg = getattr(config, "text_config", None)
    if inner_cfg is None:
        return False
    return inner_cfg.__class__ in MODEL_FOR_CAUSAL_LM_MAPPING


def _extract_causal_lm_from_vlm(model_path: str, load_kwargs: dict) -> torch.nn.Module:
    """Load a VLM checkpoint and peel out its standalone CausalLM text model.

    Loads the full VLM (e.g. Mistral3ForConditionalGeneration), instantiates
    the matching CausalLM class (e.g. Ministral3ForCausalLM) from the nested
    text config, and moves the ``language_model`` and ``lm_head`` modules over.

    Args:
        model_path: HuggingFace model name or path.
        load_kwargs: Keyword arguments forwarded to from_pretrained.

    Returns:
        A standalone CausalLM model.

    Raises:
        ValueError: if the loaded VLM does not expose the expected
            ``model.language_model`` or ``lm_head`` attributes.
    """
    from transformers import AutoModelForImageTextToText

    log_rank_0("🔄 VLM detected – loading full VLM to extract CausalLM text backbone")
    # Drop an explicit quantization_config=None so it cannot interfere with
    # the model's built-in quantization handling (e.g. FP8 auto-dequant).
    filtered_kwargs = {
        key: val
        for key, val in load_kwargs.items()
        if key != "quantization_config" or val is not None
    }
    vlm = AutoModelForImageTextToText.from_pretrained(model_path, **filtered_kwargs)

    full_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    backbone_config = full_config.text_config
    backbone_class = MODEL_FOR_CAUSAL_LM_MAPPING[backbone_config.__class__]

    log_rank_0(f"  Extracting {backbone_class.__name__} from {type(vlm).__name__}")
    text_model = backbone_class(backbone_config)

    # VLMs keep the text backbone at model.language_model; bail out loudly
    # if this layout assumption does not hold for the loaded architecture.
    if not (hasattr(vlm, "model") and hasattr(vlm.model, "language_model")):
        raise ValueError(
            f"Cannot extract language model from {type(vlm).__name__}: "
            f"expected vlm.model.language_model attribute"
        )
    text_model.model = vlm.model.language_model

    if not hasattr(vlm, "lm_head"):
        raise ValueError(f"Cannot extract lm_head from {type(vlm).__name__}")
    text_model.lm_head = vlm.lm_head

    # Free the vision tower and other VLM-only parts before training starts.
    del vlm
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    log_rank_0(f"  ✅ Extracted {backbone_class.__name__} successfully")
    return text_model


def _distributed_initialized() -> bool:
"""
Returns True when torch.distributed is both available and initialized.
Expand Down Expand Up @@ -817,12 +887,20 @@ def setup_sft_model_distributed(
state_dict = None
buffer_dict = None

# Check if this is a VLM wrapping a CausalLM text backbone
_model_config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
is_vlm = _is_vlm_wrapping_causal_lm(_model_config)

if dist.get_rank() == 0:
log_rank_0("rank 0: loading model to CPU")
try:
with torch.no_grad():
# Default load targets CPU when no device_map or accelerate is present
cpu_model = ModelClass.from_pretrained(**base_model_args)
if is_vlm:
# VLM model: load full VLM and extract CausalLM text backbone
cpu_model = _extract_causal_lm_from_vlm(model_name_or_path, base_model_args)
else:
# Standard CausalLM: load directly
cpu_model = ModelClass.from_pretrained(**base_model_args)
cpu_model = align_model_and_tokenizer(cpu_model, tokenizer)
config = cpu_model.config
state_dict = cpu_model.state_dict()
Expand Down Expand Up @@ -882,6 +960,32 @@ def setup_model(
model_config = AutoConfig.from_pretrained(model_name_or_path)
is_gpt_oss = is_gpt_oss_model(model_config)

# The Hub kernel for mamba-ssm is incompatible with causal_conv1d v1.6.0
# (different C API). Use local packages instead.
if getattr(model_config, "model_type", None) == "granitemoehybrid":
try:
from transformers.integrations.hub_kernels import _KERNEL_MODULE_MAPPING
import causal_conv1d
import mamba_ssm
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
from mamba_ssm.ops.triton.ssd_combined import (
mamba_chunk_scan_combined,
mamba_split_conv1d_scan_combined,
)

mamba_ssm.selective_state_update = selective_state_update
mamba_ssm.mamba_chunk_scan_combined = mamba_chunk_scan_combined
mamba_ssm.mamba_split_conv1d_scan_combined = mamba_split_conv1d_scan_combined

_KERNEL_MODULE_MAPPING["causal-conv1d"] = causal_conv1d
_KERNEL_MODULE_MAPPING["mamba-ssm"] = mamba_ssm
log_rank_0("Using local mamba_ssm/causal_conv1d instead of Hub kernels")
except ImportError:
log_rank_0(
"mamba_ssm or causal_conv1d not installed; "
"GraniteMoeHybrid will use slow (torch) path"
)

# Set up quantization config for GPT-OSS models
if is_gpt_oss:
try:
Expand All @@ -896,13 +1000,37 @@ def setup_model(
except ImportError:
log_rank_0("⚠️ GPT-OSS model detected but Mxfp4Config not available - using default config")

# Check if model uses M-RoPE (multimodal rotary position embeddings).
# Models with M-RoPE (e.g. Qwen3.5) pass 3D position_ids through kwargs which
# causes Flash Attention 2's _is_packed_sequence() to misinterpret them as packed
# sequences, leading to incorrect cu_seqlens computation and CUDA illegal memory
# access. Force SDPA for these models.
_text_config = getattr(model_config, "text_config", model_config)
_rope_params = getattr(_text_config, "rope_parameters", {}) or {}
_uses_mrope = "mrope_section" in _rope_params

# Check if flash_attn is available and set appropriate attention implementation
try:
import flash_attn as _ # noqa: F401

if is_gpt_oss:
base_model_args["attn_implementation"] = "kernels-community/vllm-flash-attn3"
log_rank_0("Set attention implementation to vllm-flash-attn3 for GPT-OSS")
if _uses_mrope:
base_model_args["attn_implementation"] = "sdpa"
log_rank_0(
f"Using SDPA for {model_name_or_path} (M-RoPE model incompatible with Flash Attention 2 varlen path)"
)
elif is_gpt_oss:
# vllm-flash-attn3 requires Hopper (SM 9.0+) GPUs;
# GPT-OSS only supports flash-attn3 or eager
major, _ = torch.cuda.get_device_capability(0)
if major >= 9:
base_model_args["attn_implementation"] = "kernels-community/vllm-flash-attn3"
log_rank_0("Set attention implementation to vllm-flash-attn3 for GPT-OSS")
else:
base_model_args["attn_implementation"] = "eager"
log_rank_0(
f"GPT-OSS: flash-attn3 requires Hopper (SM 9.0+) GPUs, "
f"but found SM {major}.x. Using eager attention instead."
)
else:
base_model_args["attn_implementation"] = "flash_attention_2"

Expand Down Expand Up @@ -961,7 +1089,10 @@ def load_standard_model():
)
else:
# non-distributed path: direct loading
model = ModelClass.from_pretrained(**base_model_args)
if _is_vlm_wrapping_causal_lm(model_config):
model = _extract_causal_lm_from_vlm(model_name_or_path, base_model_args)
else:
model = ModelClass.from_pretrained(**base_model_args)
return align_model_and_tokenizer(model, tokenizer)

def load_osft_model():
Expand Down Expand Up @@ -1070,6 +1201,7 @@ def load_osft_model():
# List of supported architectures
if class_name not in [
"MistralForCausalLM",
"Ministral3ForCausalLM",
"GPTDolomiteForCausalLM",
"LlamaForCausalLM",
"Starcoder2ForCausalLM",
Expand All @@ -1080,6 +1212,7 @@ def load_osft_model():
"Qwen2ForCausalLM",
"Phi3ForCausalLM", # covers phi3 and phi4
"Qwen3ForCausalLM",
"Qwen3_5ForCausalLM",
]:
log_rank_0(
f"\033[38;2;255;255;0mWarning: Model class name: {class_name} is not in the list of supported models.\033[0m",
Expand Down
13 changes: 10 additions & 3 deletions src/mini_trainer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,16 @@ def get_model_class_from_config(model_path):
mapping = MODEL_FOR_CAUSAL_LM_MAPPING

config_class = config.__class__
if config_class not in mapping:
raise ValueError(f"Model class {config_class} not found in mapping {mapping}")
return mapping[config_class]
if config_class in mapping:
return mapping[config_class]

# Fallback: for VLM models that wrap a CausalLM text backbone
# (e.g. Mistral3Config wrapping Ministral3Config), check text_config
text_config = getattr(config, "text_config", None)
if text_config is not None and text_config.__class__ in mapping:
return mapping[text_config.__class__]

raise ValueError(f"Model class {config_class} not found in mapping {mapping}")


def destroy_distributed_environment():
Expand Down
Loading