Skip to content

Commit 3c1e9d7

Browse files
RobotSail and claude authored
add support for qwen3.5 vl model (#70)
* add support for qwen3.5 vl model * enable detection of VLM models and allow using non-Hopper GPUs for GPT-OSS * fix gpt-oss-20b initialization * add support for more vlms * adds general vlm support * support gemma3n * address coderabbit review comments - Reorder MODEL_NAME_MAPPINGS for correct substring matching - Filter pretrained_model_name_or_path from VLM load kwargs - Move SDPA decision outside flash_attn import block Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * fix test: prevent MagicMock auto-creating VLM attributes Use spec=[] on mock model.model to prevent hasattr from falsely detecting language_model attribute in wrap_fsdp2 tests. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * address remaining review comments - osft_utils.py: filter pretrained_model_name_or_path from VLM kwargs in OSFT path to prevent duplicate argument error - osft_utils.py: add hasattr guard for _can_set_experts_implementation on non-MoE base classes - vlm_utils.py: handle RopeParameters objects (not just dicts) in mrope detection via hasattr fallback - Fix isort ordering Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * fix pre-existing test broken by VLM detection Wrap AutoConfig.from_pretrained in try/except in _load_model_memory_efficient so mock/dummy model paths don't crash the VLM detection check. 
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * fix mamba kernel patching: broaden exception handling, fix comment - Catch AttributeError in addition to ImportError to prevent partial patching of _KERNEL_MODULE_MAPPING - Update comment to accurately describe the compatibility concern (PyTorch/CUDA ABI mismatch, not C API incompatibility) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * fix format and lint for all changed files - Reformat 8 test/source files to match CI ruff version - Fix UP038: use X | Y in isinstance call Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * fix VLM OSFT support for direct-loaded models - Remove unnecessary OSFT guard for direct VLMs (patterns match fine) - Add _get_text_config() helper for VLM config fallback in align_model_and_tokenizer (vocab_size, pad/bos/eos_token_id) - Fix model.config.pad_token_id access in train.py for VLM configs - Skip activation checkpointing for direct VLM models (M-RoPE layers produce non-deterministic tensor counts during reentrant recomputation) - Use dynamic ports (_get_free_port) in model_validation.py to prevent port conflicts between sequential tests Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * fix ruff format for CI version (0.15.5) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 4d6dc87 commit 3c1e9d7

12 files changed

+520
-115
lines changed

src/mini_trainer/osft_utils.py

Lines changed: 47 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -10,20 +10,14 @@
1010
import torch
1111
import torch.distributed as dist
1212
import torch.nn as nn
13-
from torch.distributed.checkpoint.state_dict import (
14-
StateDictOptions,
15-
set_model_state_dict,
16-
)
13+
from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict
1714
from tqdm import tqdm
1815
from transformers.models.gpt_oss.modeling_gpt_oss import GptOssForCausalLM
1916

20-
from mini_trainer.fsdp2_lazy_init import (
21-
FSDP2_LAZY_INIT_OSFT,
22-
get_fsdp2_lazy_init_mode,
23-
set_fsdp2_lazy_init_mode,
24-
)
17+
from mini_trainer.fsdp2_lazy_init import FSDP2_LAZY_INIT_OSFT, get_fsdp2_lazy_init_mode, set_fsdp2_lazy_init_mode
2518
from mini_trainer.gpt_oss_utils import is_gpt_oss_model
2619
from mini_trainer.utils import get_control_process_group, log_rank_0
20+
from mini_trainer.vlm_utils import extract_causal_lm_from_vlm, is_vlm_with_causal_lm
2721

2822
# Memory optimization constants
2923
OSFT_CACHE_CLEAR_INTERVAL = int(
@@ -315,6 +309,18 @@ def _reconstruct_weight(
315309
# "experts.down_proj",
316310
]
317311
},
312+
"qwen3_5": {
313+
"patterns": [
314+
"self_attn.q_proj",
315+
"self_attn.k_proj",
316+
"self_attn.v_proj",
317+
"self_attn.o_proj",
318+
"linear_attn.out_proj",
319+
"mlp.gate_proj",
320+
"mlp.down_proj",
321+
"mlp.up_proj",
322+
]
323+
},
318324
"default": {
319325
"patterns": [
320326
"self_attn.q_proj",
@@ -337,6 +343,8 @@ def _reconstruct_weight(
337343
"gptneo": "gpt-neo", # Handle both "gpt-neo" and "gptneo" variants
338344
"gpt-oss": "gpt-oss",
339345
"opt": "opt",
346+
"qwen3_5": "qwen3_5", # specific BEFORE generic
347+
"qwen3.5": "qwen3_5", # specific BEFORE generic
340348
"qwen": "qwen",
341349
"gemma": "gemma",
342350
"phi4": "phi3",
@@ -346,8 +354,6 @@ def _reconstruct_weight(
346354
"mistral": "mistral",
347355
"granite": "granite",
348356
"gpt2": "gpt2",
349-
# Easy to add more mappings
350-
# "phi": "phi",
351357
}
352358

353359

@@ -767,11 +773,30 @@ def _load_model_memory_efficient(
767773
if dist.get_rank() == 0:
768774
with torch.no_grad():
769775
log_rank_0(f"📥 Loading base model to CPU in {load_dtype}...")
770-
base_model = base_model_class.from_pretrained(
771-
pretrained_model_name_or_path,
772-
*model_args,
773-
**final_base_kwargs,
774-
)
776+
777+
# Check if this is a VLM wrapping a CausalLM text backbone
778+
_is_vlm = False
779+
try:
780+
from transformers import AutoConfig
781+
782+
_pre_config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)
783+
_is_vlm = is_vlm_with_causal_lm(_pre_config)
784+
except (OSError, ValueError):
785+
# Config loading can fail for local-only or mock paths
786+
pass
787+
788+
if _is_vlm:
789+
log_rank_0("🔄 VLM detected – extracting CausalLM text backbone for OSFT")
790+
# Filter out pretrained_model_name_or_path to avoid duplicate
791+
# argument since it's passed positionally to extract_causal_lm_from_vlm
792+
vlm_kwargs = {k: v for k, v in final_base_kwargs.items() if k != "pretrained_model_name_or_path"}
793+
base_model = extract_causal_lm_from_vlm(pretrained_model_name_or_path, vlm_kwargs)
794+
else:
795+
base_model = base_model_class.from_pretrained(
796+
pretrained_model_name_or_path,
797+
*model_args,
798+
**final_base_kwargs,
799+
)
775800

776801
align_fn = osft_class_kwargs.get("lazy_init_tokenizer_align_fn")
777802
if align_fn:
@@ -888,6 +913,12 @@ def create_osft_model_class(base_cls) -> type[OSFTModel]:
888913
class ModelWithOSFT(base_cls):
889914
osft_config: dict[str, int]
890915

916+
@classmethod
917+
def _can_set_experts_implementation(cls):
918+
if hasattr(base_cls, "_can_set_experts_implementation"):
919+
return base_cls._can_set_experts_implementation()
920+
return False
921+
891922
def __init__(
892923
self,
893924
config,

src/mini_trainer/sampler.py

Lines changed: 1 addition & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -29,10 +29,7 @@
2929
import numpy as np
3030
import torch
3131
import torch.distributed as dist
32-
from datasets import (
33-
Dataset as HFDataset,
34-
load_dataset,
35-
)
32+
from datasets import Dataset as HFDataset, load_dataset
3633
from deprecated import deprecated
3734
from torch.utils.data import DataLoader, Dataset, Sampler, SequentialSampler
3835

0 commit comments

Comments (0)