diff --git a/src/mini_trainer/osft_utils.py b/src/mini_trainer/osft_utils.py index 63d0837..1c9753a 100644 --- a/src/mini_trainer/osft_utils.py +++ b/src/mini_trainer/osft_utils.py @@ -10,20 +10,14 @@ import torch import torch.distributed as dist import torch.nn as nn -from torch.distributed.checkpoint.state_dict import ( - StateDictOptions, - set_model_state_dict, -) +from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict from tqdm import tqdm from transformers.models.gpt_oss.modeling_gpt_oss import GptOssForCausalLM -from mini_trainer.fsdp2_lazy_init import ( - FSDP2_LAZY_INIT_OSFT, - get_fsdp2_lazy_init_mode, - set_fsdp2_lazy_init_mode, -) +from mini_trainer.fsdp2_lazy_init import FSDP2_LAZY_INIT_OSFT, get_fsdp2_lazy_init_mode, set_fsdp2_lazy_init_mode from mini_trainer.gpt_oss_utils import is_gpt_oss_model from mini_trainer.utils import get_control_process_group, log_rank_0 +from mini_trainer.vlm_utils import extract_causal_lm_from_vlm, is_vlm_with_causal_lm # Memory optimization constants OSFT_CACHE_CLEAR_INTERVAL = int( @@ -315,6 +309,18 @@ def _reconstruct_weight( # "experts.down_proj", ] }, + "qwen3_5": { + "patterns": [ + "self_attn.q_proj", + "self_attn.k_proj", + "self_attn.v_proj", + "self_attn.o_proj", + "linear_attn.out_proj", + "mlp.gate_proj", + "mlp.down_proj", + "mlp.up_proj", + ] + }, "default": { "patterns": [ "self_attn.q_proj", @@ -337,6 +343,8 @@ def _reconstruct_weight( "gptneo": "gpt-neo", # Handle both "gpt-neo" and "gptneo" variants "gpt-oss": "gpt-oss", "opt": "opt", + "qwen3_5": "qwen3_5", # specific BEFORE generic + "qwen3.5": "qwen3_5", # specific BEFORE generic "qwen": "qwen", "gemma": "gemma", "phi4": "phi3", @@ -346,8 +354,6 @@ def _reconstruct_weight( "mistral": "mistral", "granite": "granite", "gpt2": "gpt2", - # Easy to add more mappings - # "phi": "phi", } @@ -767,11 +773,30 @@ def _load_model_memory_efficient( if dist.get_rank() == 0: with torch.no_grad(): log_rank_0(f"📥 Loading base model 
@classmethod
def _can_set_experts_implementation(cls):
    """Tell callers whether the experts implementation may be set.

    The answer is delegated to the wrapped ``base_cls``: when it defines
    ``_can_set_experts_implementation`` the result of calling that hook
    is returned unchanged; otherwise the answer is ``False``.
    """
    # Guard clause: base classes without the hook never opt in.
    if not hasattr(base_cls, "_can_set_experts_implementation"):
        return False
    return base_cls._can_set_experts_implementation()
a/src/mini_trainer/setup_model_for_training.py b/src/mini_trainer/setup_model_for_training.py index 55280f4..1456128 100644 --- a/src/mini_trainer/setup_model_for_training.py +++ b/src/mini_trainer/setup_model_for_training.py @@ -7,13 +7,8 @@ import torch import torch.distributed as dist -from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( - checkpoint_wrapper as ptd_checkpoint_wrapper, -) -from torch.distributed.checkpoint.state_dict import ( - StateDictOptions, - set_model_state_dict, -) +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper as ptd_checkpoint_wrapper +from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict from torch.distributed.device_mesh import init_device_mesh from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard from transformers import AutoConfig, AutoTokenizer, Mxfp4Config @@ -26,16 +21,15 @@ set_fsdp2_lazy_init_mode, ) from mini_trainer.gpt_oss_utils import freeze_router_params, is_gpt_oss_model -from mini_trainer.osft_utils import ( - OSFTModel, - _build_osft_kwargs, - _set_osft_dtypes, - create_osft_model_class, -) -from mini_trainer.utils import ( - get_model_class_from_config, - log_rank_0, - patch_target_module, +from mini_trainer.osft_utils import OSFTModel, _build_osft_kwargs, _set_osft_dtypes, create_osft_model_class +from mini_trainer.utils import get_model_class_from_config, log_rank_0, patch_target_module +from mini_trainer.vlm_utils import ( + extract_causal_lm_from_vlm, + has_timm_vision_tower, + is_vlm_for_direct_loading, + is_vlm_with_causal_lm, + load_vlm_for_text_training, + needs_sdpa, ) @@ -66,10 +60,7 @@ def _apply_liger_kernels_if_requested(use_liger_kernels, model_config, base_mode return try: - from liger_kernel.transformers.monkey_patch import ( - MODEL_TYPE_TO_APPLY_LIGER_FN, - _apply_liger_kernel, - ) + from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN, _apply_liger_kernel 
except ImportError as e: raise ImportError( "Tried to use liger kernels for OSFT, but they are not installed. " @@ -438,8 +429,17 @@ def wrap_fsdp2(model: torch.nn.Module) -> torch.nn.Module: print(f"WARNING: Failed to disable HuggingFace cache for model {model.__class__.__name__}: {e}") # Find the transformer block container - # Support common architectures: Llama (model.layers), GPT-2 (transformer.h) - if hasattr(model, "model") and hasattr(model.model, "layers"): + # Support common architectures: + # - VLM direct load (model.model.language_model.layers) e.g. Qwen3-VL + # - Llama-style (model.model.layers) + # - GPT-2-style (transformer.h) + if ( + hasattr(model, "model") + and hasattr(model.model, "language_model") + and hasattr(model.model.language_model, "layers") + ): + layers = model.model.language_model.layers + elif hasattr(model, "model") and hasattr(model.model, "layers"): layers = model.model.layers elif hasattr(model, "transformer") and hasattr(model.transformer, "h"): layers = model.transformer.h @@ -449,11 +449,19 @@ def wrap_fsdp2(model: torch.nn.Module) -> torch.nn.Module: "This likely means we need to update the code to support this model." ) - # Apply activation checkpointing to each block - log_rank_0(f"🔄 [Phase 2] Applying activation checkpointing to {len(layers)} blocks") - for idx, block in enumerate(layers): - # preserve_rng_state needs to be true so that the backward pass can be accurate - layers[idx] = ptd_checkpoint_wrapper(block, preserve_rng_state=True) + # Apply activation checkpointing to each block. + # VLM models (detected by language_model path) may have non-deterministic + # tensor counts during reentrant recomputation (e.g., M-RoPE), so we skip + # activation checkpointing for them. This uses more memory but avoids + # CheckpointError from tensor count mismatches. 
+ is_vlm_direct = hasattr(model, "model") and hasattr(model.model, "language_model") + if is_vlm_direct: + log_rank_0(f"🔄 [Phase 2] Skipping activation checkpointing for VLM ({len(layers)} blocks)") + else: + log_rank_0(f"🔄 [Phase 2] Applying activation checkpointing to {len(layers)} blocks") + for idx, block in enumerate(layers): + # preserve_rng_state needs to be true so that the backward pass can be accurate + layers[idx] = ptd_checkpoint_wrapper(block, preserve_rng_state=True) # Build 1D device mesh over all ranks world_size = dist.get_world_size() @@ -577,12 +585,21 @@ def finalize_model_initialization(model: torch.nn.Module, context: ModelInitiali return model +def _get_text_config(model): + """Get the text-relevant config, falling back to text_config for VLMs.""" + config = model.config + if not hasattr(config, "vocab_size") and hasattr(config, "text_config"): + return config.text_config + return config + + def align_model_and_tokenizer(model, tokenizer): """ Aligns the model's vocabulary and special tokens with the tokenizer. """ - if len(tokenizer) > model.config.vocab_size: - print(f"WARNING: tokenizer has {len(tokenizer)} tokens but model has {model.config.vocab_size} vocab size") + text_config = _get_text_config(model) + if len(tokenizer) > text_config.vocab_size: + print(f"WARNING: tokenizer has {len(tokenizer)} tokens but model has {text_config.vocab_size} vocab size") model.resize_token_embeddings( int(8 * math.ceil(len(tokenizer) / 8.0)) ) # make the vocab size multiple of 8 for sharding the embedding layer. @@ -603,8 +620,8 @@ def align_model_and_tokenizer(model, tokenizer): "Cannot proceed with training - please configure the tokenizer properly." 
) - # Step 2: Sync all special tokens from tokenizer to model.config - # This ensures model.config always reflects tokenizer's special tokens + # Step 2: Sync all special tokens from tokenizer to text_config + # This ensures the config always reflects tokenizer's special tokens special_tokens = { "pad": ("pad_token_id", "Syncing model pad token id"), "bos": ("bos_token_id", "Syncing model bos token id"), @@ -612,13 +629,13 @@ def align_model_and_tokenizer(model, tokenizer): } for token_type, (token_attr, message) in special_tokens.items(): - model_token = getattr(model.config, token_attr) + model_token = getattr(text_config, token_attr, None) tokenizer_token = getattr(tokenizer, token_attr) - # Always sync tokenizer -> model.config when tokenizer has a valid value + # Always sync tokenizer -> config when tokenizer has a valid value if tokenizer_token is not None and model_token != tokenizer_token: log_rank_0(f"{message}: {model_token} -> {tokenizer_token}") - setattr(model.config, token_attr, tokenizer_token) + setattr(text_config, token_attr, tokenizer_token) return model @@ -674,7 +691,7 @@ def get_model_save_dtype(save_dtype: str | torch.dtype | None, model_name_or_pat return original_dtype # by now we know that we are going to use a custom data type, so we just validate - if not isinstance(save_dtype, (str, torch.dtype)): + if not isinstance(save_dtype, str | torch.dtype): raise ValueError(f"error: could not recognize '{save_dtype}' as a supported dtype for saving model checkpoints") # convert dtype to a str @@ -817,12 +834,24 @@ def setup_sft_model_distributed( state_dict = None buffer_dict = None + # Check if this is a VLM wrapping a CausalLM text backbone, or a direct VLM + _model_config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True) + is_vlm = is_vlm_with_causal_lm(_model_config) + is_direct_vlm = is_vlm_for_direct_loading(_model_config) + if dist.get_rank() == 0: log_rank_0("rank 0: loading model to CPU") try: with 
torch.no_grad(): - # Default load targets CPU when no device_map or accelerate is present - cpu_model = ModelClass.from_pretrained(**base_model_args) + if is_vlm: + # VLM model: load full VLM and extract CausalLM text backbone + cpu_model = extract_causal_lm_from_vlm(model_name_or_path, base_model_args) + elif is_direct_vlm: + # VLM with no CausalLM class: load directly for text-only training + cpu_model = load_vlm_for_text_training(model_name_or_path, base_model_args) + else: + # Standard CausalLM: load directly + cpu_model = ModelClass.from_pretrained(**base_model_args) cpu_model = align_model_and_tokenizer(cpu_model, tokenizer) config = cpu_model.config state_dict = cpu_model.state_dict() @@ -845,8 +874,14 @@ def setup_sft_model_distributed( # All ranks: Create model on meta device log_rank_0("creating model on meta device") - with torch.device("meta"): - model = ModelClass.from_config(config) + if is_direct_vlm: + from transformers import AutoModelForImageTextToText + + with torch.device("meta"): + model = AutoModelForImageTextToText.from_config(config) + else: + with torch.device("meta"): + model = ModelClass.from_config(config) # Align model with tokenizer model = align_model_and_tokenizer(model, tokenizer) @@ -882,6 +917,28 @@ def setup_model( model_config = AutoConfig.from_pretrained(model_name_or_path) is_gpt_oss = is_gpt_oss_model(model_config) + # Pre-populate the transformers Hub kernel cache with locally installed + # mamba_ssm and causal_conv1d packages. The Hub kernel versions may be + # compiled against a different PyTorch/CUDA build, causing runtime errors. + # Using local packages ensures ABI compatibility. 
+ if getattr(model_config, "model_type", None) == "granitemoehybrid": + try: + import causal_conv1d + import mamba_ssm + from mamba_ssm.ops.triton.selective_state_update import selective_state_update + from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined + from transformers.integrations.hub_kernels import _KERNEL_MODULE_MAPPING + + mamba_ssm.selective_state_update = selective_state_update + mamba_ssm.mamba_chunk_scan_combined = mamba_chunk_scan_combined + mamba_ssm.mamba_split_conv1d_scan_combined = mamba_split_conv1d_scan_combined + + _KERNEL_MODULE_MAPPING["causal-conv1d"] = causal_conv1d + _KERNEL_MODULE_MAPPING["mamba-ssm"] = mamba_ssm + log_rank_0("Using local mamba_ssm/causal_conv1d instead of Hub kernels") + except (ImportError, AttributeError) as e: + log_rank_0(f"Could not patch mamba kernels ({e}); GraniteMoeHybrid may use Hub kernels") + # Set up quantization config for GPT-OSS models if is_gpt_oss: try: @@ -896,21 +953,52 @@ def setup_model( except ImportError: log_rank_0("⚠️ GPT-OSS model detected but Mxfp4Config not available - using default config") - # Check if flash_attn is available and set appropriate attention implementation - try: - import flash_attn as _ # noqa: F401 + # Check if model requires SDPA instead of Flash Attention 2. + # This covers M-RoPE models (3D position_ids) and models with timm vision + # towers (TimmWrapperModel rejects flash_attention_2). 
+ _needs_sdpa = needs_sdpa(model_config) - if is_gpt_oss: - base_model_args["attn_implementation"] = "kernels-community/vllm-flash-attn3" - log_rank_0("Set attention implementation to vllm-flash-attn3 for GPT-OSS") - else: - base_model_args["attn_implementation"] = "flash_attention_2" + # Handle models that need SDPA (doesn't require flash_attn) + if _needs_sdpa: + base_model_args["attn_implementation"] = "sdpa" + log_rank_0(f"Using SDPA for {model_name_or_path} (model incompatible with Flash Attention 2)") + else: + # Check if flash_attn is available for non-SDPA models + try: + import flash_attn as _ + + if is_gpt_oss: + # vllm-flash-attn3 requires Hopper (SM 9.0+) GPUs; + # GPT-OSS only supports flash-attn3 or eager + major, _ = torch.cuda.get_device_capability(0) + if major >= 9: + base_model_args["attn_implementation"] = "kernels-community/vllm-flash-attn3" + log_rank_0("Set attention implementation to vllm-flash-attn3 for GPT-OSS") + else: + base_model_args["attn_implementation"] = "eager" + log_rank_0( + f"GPT-OSS: flash-attn3 requires Hopper (SM 9.0+) GPUs, " + f"but found SM {major}.x. Using eager attention instead." + ) + else: + base_model_args["attn_implementation"] = "flash_attention_2" - except ImportError as e: - if os.environ.get("TESTING", "false").lower() == "true": - base_model_args["attn_implementation"] = "sdpa" - else: - raise e + except ImportError as e: + if os.environ.get("TESTING", "false").lower() == "true": + base_model_args["attn_implementation"] = "sdpa" + else: + raise e + + # For models with timm vision towers: set vision config to eager + # while keeping the text model's attention implementation. + # timm's TimmWrapperModel rejects both FA2 and SDPA. 
+ if has_timm_vision_tower(model_config): + attn_impl = base_model_args.get("attn_implementation", "flash_attention_2") + base_model_args["attn_implementation"] = { + "text_config": attn_impl, + "vision_config": "eager", + } + log_rank_0(f"Model has timm vision tower — using eager attention for vision, {attn_impl} for text model.") tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) @@ -961,7 +1049,12 @@ def load_standard_model(): ) else: # non-distributed path: direct loading - model = ModelClass.from_pretrained(**base_model_args) + if is_vlm_with_causal_lm(model_config): + model = extract_causal_lm_from_vlm(model_name_or_path, base_model_args) + elif is_vlm_for_direct_loading(model_config): + model = load_vlm_for_text_training(model_name_or_path, base_model_args) + else: + model = ModelClass.from_pretrained(**base_model_args) return align_model_and_tokenizer(model, tokenizer) def load_osft_model(): @@ -1070,6 +1163,7 @@ def load_osft_model(): # List of supported architectures if class_name not in [ "MistralForCausalLM", + "Ministral3ForCausalLM", "GPTDolomiteForCausalLM", "LlamaForCausalLM", "Starcoder2ForCausalLM", @@ -1080,6 +1174,9 @@ def load_osft_model(): "Qwen2ForCausalLM", "Phi3ForCausalLM", # covers phi3 and phi4 "Qwen3ForCausalLM", + "Qwen3_5ForCausalLM", + "Qwen3VLForConditionalGeneration", # direct VLM loading (no CausalLM class) + "Gemma3nForConditionalGeneration", # dual-registered VLM loaded as CausalLM ]: log_rank_0( f"\033[38;2;255;255;0mWarning: Model class name: {class_name} is not in the list of supported models.\033[0m", diff --git a/src/mini_trainer/train.py b/src/mini_trainer/train.py index aea386a..621f636 100644 --- a/src/mini_trainer/train.py +++ b/src/mini_trainer/train.py @@ -8,9 +8,7 @@ import torch import torch.distributed as dist -from torch.distributed._tensor.api import ( - DTensor as _DTensor, -) # works if DTensor is available +from torch.distributed._tensor.api import DTensor as _DTensor # works if DTensor is available 
def get_model_class_from_config(model_path):
    """Resolve the model class (not just its name) for a pretrained path.

    The config is loaded from ``model_path`` and looked up in this order:

    1. the config's own class in the CausalLM mapping;
    2. the class of ``config.text_config`` in the CausalLM mapping, for
       VLMs that wrap a CausalLM text backbone (e.g. Mistral3Config
       wrapping Ministral3Config);
    3. the config's own class in the ImageTextToText mapping, for VLMs
       with no CausalLM class at all (e.g. Qwen3-VL-2B).

    Note: vlm_utils.is_vlm_with_causal_lm() handles the broader VLM
    detection (deciding *whether* to extract). This function only resolves
    the model CLASS, regardless of VLM status.

    Raises:
        ValueError: if none of the three lookups matches.
    """
    # TODO: make the `trust_remote_code` setting configurable somehow
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    mapping = MODEL_FOR_CAUSAL_LM_MAPPING
    config_class = config.__class__

    # CausalLM candidates, most specific first: the config itself, then its
    # nested text_config (when there is one).
    candidates = [config_class]
    text_config = getattr(config, "text_config", None)
    if text_config is not None:
        candidates.append(text_config.__class__)
    for candidate in candidates:
        if candidate in mapping:
            return mapping[candidate]

    # Last resort: the ImageTextToText mapping. Imported only once the
    # CausalLM lookups have failed.
    from transformers.models.auto import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING

    if config_class not in MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING:
        raise ValueError(f"Model class {config_class} not found in mapping {mapping}")
    return MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING[config_class]
+""" + +import torch +import torch.nn as nn +from transformers.models.auto import MODEL_FOR_CAUSAL_LM_MAPPING + +from mini_trainer.utils import log_rank_0 + + +def is_vlm_with_causal_lm(config) -> bool: + """Check if a model config is a VLM wrapping a CausalLM text backbone. + + Returns True only when the top-level config is NOT in the CausalLM + mapping but its nested ``text_config`` IS. Models that are directly + registered as CausalLM (even if they also have a text_config) return + False. + + Args: + config: An already-loaded HuggingFace model config object. + + Returns: + True if the model is a VLM wrapping a CausalLM text backbone. + """ + if config.__class__ in MODEL_FOR_CAUSAL_LM_MAPPING: + return False + text_config = getattr(config, "text_config", None) + return text_config is not None and text_config.__class__ in MODEL_FOR_CAUSAL_LM_MAPPING + + +def is_vlm_for_direct_loading(config) -> bool: + """Check if a model config is a VLM that should be loaded directly for text-only training. + + Returns True when the model is NOT in the CausalLM mapping, has no + extractable CausalLM text backbone (via ``text_config``), but IS + registered in the ImageTextToText mapping. This covers models like + Qwen3-VL-2B that have no standalone CausalLM class at all. + + Args: + config: An already-loaded HuggingFace model config object. + + Returns: + True if the model is a VLM that should be loaded directly. 
+ """ + from transformers.models.auto import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING + + # Already a CausalLM — load normally + if config.__class__ in MODEL_FOR_CAUSAL_LM_MAPPING: + return False + + # Has an extractable CausalLM text backbone — use extraction path + text_config = getattr(config, "text_config", None) + if text_config is not None and text_config.__class__ in MODEL_FOR_CAUSAL_LM_MAPPING: + return False + + # Is a VLM with no CausalLM mapping at all — load directly + return config.__class__ in MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING + + +def load_vlm_for_text_training(model_path: str, load_kwargs: dict) -> nn.Module: + """Load a VLM directly for text-only training. + + Used for VLM models that have no standalone CausalLM class (detected + by :func:`is_vlm_for_direct_loading`). The full VLM is loaded via + ``AutoModelForImageTextToText`` and used as-is for text-only forward + passes (input_ids + labels). + + Note: The layer structure for these models is typically + ``model.model.language_model.layers`` rather than ``model.model.layers``. + + Args: + model_path: HuggingFace model name or local path. + load_kwargs: Keyword arguments forwarded to ``from_pretrained``. + + Returns: + The loaded VLM model ready for text-only training. + """ + from transformers import AutoModelForImageTextToText + + log_rank_0("🔄 VLM detected (no CausalLM class) – loading directly for text-only training") + + # Filter out None quantization_config to avoid interfering with + # the model's built-in quantization handling. + # Also filter out pretrained_model_name_or_path since model_path is passed positionally. 
+ filtered_kwargs = { + k: v + for k, v in load_kwargs.items() + if k != "pretrained_model_name_or_path" and not (k == "quantization_config" and v is None) + } + model = AutoModelForImageTextToText.from_pretrained(model_path, **filtered_kwargs) + + log_rank_0(f" ✅ Loaded {type(model).__name__} directly for text-only training") + return model + + +def _find_text_backbone(vlm_model: nn.Module) -> nn.Module: + """Auto-detect the text backbone inside a VLM model. + + Tries well-known attribute names first (``language_model``, + ``text_model``, ``llm``), then falls back to searching + ``named_children`` for class names containing ``ForCausalLM`` or + ``TextModel``. + + Args: + vlm_model: The loaded VLM model. + + Returns: + The text backbone module. + + Raises: + ValueError: If no text backbone can be found. + """ + inner = vlm_model.model if hasattr(vlm_model, "model") else vlm_model + + # Well-known attribute names + for attr_name in ("language_model", "text_model", "llm"): + if hasattr(inner, attr_name): + return getattr(inner, attr_name) + + # Fallback: search named children for common class-name patterns + for name, child in inner.named_children(): + cls_name = child.__class__.__name__ + if "ForCausalLM" in cls_name or "TextModel" in cls_name: + return child + + available = [name for name, _ in inner.named_children()] + raise ValueError( + f"Cannot find text backbone in {type(vlm_model).__name__}. Available sub-modules on inner model: {available}" + ) + + +def extract_causal_lm_from_vlm(model_path: str, load_kwargs: dict) -> nn.Module: + """Load a VLM and extract the CausalLM text backbone. + + Loads the full VLM via ``AutoModelForImageTextToText``, auto-detects + the text backbone using :func:`_find_text_backbone`, then creates a + standalone CausalLM model by transferring weights. + + Args: + model_path: HuggingFace model name or local path. + load_kwargs: Keyword arguments forwarded to ``from_pretrained``. 
+ + Returns: + A standalone CausalLM model with the VLM's text weights. + """ + from transformers import AutoConfig, AutoModelForImageTextToText + + log_rank_0("🔄 VLM detected – loading full VLM to extract CausalLM text backbone") + + # Filter out None quantization_config to avoid interfering with + # the model's built-in quantization handling (e.g. FP8 auto-dequant). + # Also filter out pretrained_model_name_or_path since model_path is passed positionally. + vlm_kwargs = { + k: v + for k, v in load_kwargs.items() + if k != "pretrained_model_name_or_path" and not (k == "quantization_config" and v is None) + } + vlm = AutoModelForImageTextToText.from_pretrained(model_path, **vlm_kwargs) + + # Auto-detect text backbone + backbone = _find_text_backbone(vlm) + + # Resolve text_config and create standalone CausalLM + config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) + text_config = config.text_config + causal_lm_class = MODEL_FOR_CAUSAL_LM_MAPPING[text_config.__class__] + + log_rank_0(f" Extracting {causal_lm_class.__name__} from {type(vlm).__name__}") + text_model = causal_lm_class(text_config) + + # Transfer backbone weights + text_model.model = backbone + + # Transfer lm_head + if hasattr(vlm, "lm_head"): + text_model.lm_head = vlm.lm_head + else: + raise ValueError(f"Cannot extract lm_head from {type(vlm).__name__}") + + del vlm + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + log_rank_0(f" ✅ Extracted {causal_lm_class.__name__} successfully") + return text_model + + +def has_mrope(config) -> bool: + """Check if a model config uses M-RoPE (multimodal rotary position embeddings). + + Inspects both the top-level config and its ``text_config`` (if present) + for ``rope_scaling`` or ``rope_parameters`` dicts containing the + ``mrope_section`` key. + + Args: + config: An already-loaded HuggingFace model config object. + + Returns: + True if M-RoPE is detected. 
+ """ + for cfg in (config, getattr(config, "text_config", None)): + if cfg is None: + continue + for attr in ("rope_scaling", "rope_parameters"): + rope_obj = getattr(cfg, attr, None) + if rope_obj is None: + continue + # Handle both dict and RopeParameters objects + if isinstance(rope_obj, dict) and "mrope_section" in rope_obj: + return True + if not isinstance(rope_obj, dict) and hasattr(rope_obj, "mrope_section"): + return True + return False + + +def needs_sdpa(config) -> bool: + """Check if a model requires SDPA instead of Flash Attention 2. + + Returns True when the model has characteristics incompatible with + Flash Attention 2: + - M-RoPE (multimodal rotary position embeddings) producing 3D position_ids + - A timm-based vision tower (TimmWrapperModel rejects flash_attention_2) + + Args: + config: An already-loaded HuggingFace model config object. + + Returns: + True if the model should use SDPA attention. + """ + if has_mrope(config): + return True + + vision_config = getattr(config, "vision_config", None) + if vision_config is not None: + model_type = getattr(vision_config, "model_type", "") + if model_type in ("timm_wrapper", "gemma3n_vision"): + return True + try: + from transformers.models.auto import MODEL_MAPPING + + if vision_config.__class__ in MODEL_MAPPING: + vision_cls = MODEL_MAPPING[vision_config.__class__] + if "Timm" in vision_cls.__name__: + return True + except Exception: + pass + + return False + + +def has_timm_vision_tower(config) -> bool: + """Check if a model config has a timm-based vision tower. + + timm vision towers only support ``eager`` attention. The vision config + must be patched to use eager while the text model can use FA2/SDPA. + + Args: + config: An already-loaded HuggingFace model config object. + + Returns: + True if the model has a timm-based vision tower. 
+ """ + vision_config = getattr(config, "vision_config", None) + if vision_config is None: + return False + model_type = getattr(vision_config, "model_type", "") + if model_type in ("timm_wrapper", "gemma3n_vision"): + return True + try: + from transformers.models.auto import MODEL_MAPPING + + if vision_config.__class__ in MODEL_MAPPING: + vision_cls = MODEL_MAPPING[vision_config.__class__] + if "Timm" in vision_cls.__name__: + return True + except Exception: + pass + return False diff --git a/tests/gpu_tests/test_distributed_utils.py b/tests/gpu_tests/test_distributed_utils.py index 2a26ccf..21acc8a 100644 --- a/tests/gpu_tests/test_distributed_utils.py +++ b/tests/gpu_tests/test_distributed_utils.py @@ -10,10 +10,7 @@ import pytest -from mini_trainer.utils import ( - check_distributed_is_synchronized, - init_distributed_environment, -) +from mini_trainer.utils import check_distributed_is_synchronized, init_distributed_environment @pytest.mark.gpu diff --git a/tests/gpu_tests/test_mixed_precision.py b/tests/gpu_tests/test_mixed_precision.py index f5e3ac9..a74052f 100644 --- a/tests/gpu_tests/test_mixed_precision.py +++ b/tests/gpu_tests/test_mixed_precision.py @@ -85,9 +85,7 @@ def test_fsdp2_mixed_precision_dtypes(self, tmp_path, single_gpu_device): tokenizer.save_pretrained(model_path) # Patch loss function for none reduction - from mini_trainer.none_reduction_losses import ( - hf_fixed_cross_entropy_none_reduction, - ) + from mini_trainer.none_reduction_losses import hf_fixed_cross_entropy_none_reduction patch_target_module( "transformers.loss.loss_utils.fixed_cross_entropy", diff --git a/tests/test_integration_small_models.py b/tests/test_integration_small_models.py index 117ecaf..b29e03e 100644 --- a/tests/test_integration_small_models.py +++ b/tests/test_integration_small_models.py @@ -27,14 +27,8 @@ Qwen2ForCausalLM, ) -from mini_trainer.osft_utils import ( - auto_generate_target_osft_config, - create_osft_model_class, -) -from 
mini_trainer.setup_model_for_training import ( - align_model_and_tokenizer, - setup_training_components, -) +from mini_trainer.osft_utils import auto_generate_target_osft_config, create_osft_model_class +from mini_trainer.setup_model_for_training import align_model_and_tokenizer, setup_training_components # TODO: add tests to validate our codebase works with these models diff --git a/tests/test_model_initialization.py b/tests/test_model_initialization.py index 0dc1c1b..7d4558a 100644 --- a/tests/test_model_initialization.py +++ b/tests/test_model_initialization.py @@ -109,9 +109,9 @@ def mock_model(self): model.config = MagicMock() model.config.use_cache = True - # Mock transformer layers + # Mock transformer layers (standard CausalLM structure: model.model.layers) layers = [MagicMock() for _ in range(4)] - model.model = MagicMock() + model.model = MagicMock(spec=[]) # spec=[] prevents auto-creating attributes model.model.layers = layers return model diff --git a/tests/test_osft_dtype_functionality.py b/tests/test_osft_dtype_functionality.py index 153d545..17a54a2 100644 --- a/tests/test_osft_dtype_functionality.py +++ b/tests/test_osft_dtype_functionality.py @@ -7,11 +7,7 @@ import torch import torch.nn as nn -from mini_trainer.osft_utils import ( - create_osft_model_class, - create_svd_dict, - reconstruct_weight_matrix, -) +from mini_trainer.osft_utils import create_osft_model_class, create_svd_dict, reconstruct_weight_matrix from mini_trainer.setup_model_for_training import setup_model diff --git a/tests/test_training_components.py b/tests/test_training_components.py index f62c00e..a55149c 100644 --- a/tests/test_training_components.py +++ b/tests/test_training_components.py @@ -20,9 +20,7 @@ from mini_trainer.batch_metrics import BatchMetrics from mini_trainer.train import save_model, take_gradient_step -from mini_trainer.utils import ( - patch_target_module, -) +from mini_trainer.utils import patch_target_module class TestTakeGradientStep: