Skip to content

Commit c7d6a86

Browse files
RobotSail and claude committed
fix pre-existing test broken by VLM detection
Wrap AutoConfig.from_pretrained in try/except in _load_model_memory_efficient so mock/dummy model paths don't crash the VLM detection check.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent d5023bd commit c7d6a86

File tree

2 files changed

+25
-13
lines changed

2 files changed

+25
-13
lines changed

src/mini_trainer/osft_utils.py

Lines changed: 9 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -775,11 +775,17 @@ def _load_model_memory_efficient(
775775
log_rank_0(f"📥 Loading base model to CPU in {load_dtype}...")
776776

777777
# Check if this is a VLM wrapping a CausalLM text backbone
778-
from transformers import AutoConfig
778+
_is_vlm = False
779+
try:
780+
from transformers import AutoConfig
779781

780-
_pre_config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)
782+
_pre_config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)
783+
_is_vlm = is_vlm_with_causal_lm(_pre_config)
784+
except (OSError, ValueError):
785+
# Config loading can fail for local-only or mock paths
786+
pass
781787

782-
if is_vlm_with_causal_lm(_pre_config):
788+
if _is_vlm:
783789
log_rank_0("🔄 VLM detected – extracting CausalLM text backbone for OSFT")
784790
# Filter out pretrained_model_name_or_path to avoid duplicate
785791
# argument since it's passed positionally to extract_causal_lm_from_vlm

tests/test_osft.py

Lines changed: 16 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -19,18 +19,24 @@
1919

2020
import mini_trainer.osft_utils as osft_module
2121
from mini_trainer.api_train import run_training
22-
from mini_trainer.osft_utils import (MODEL_CONFIGS,
23-
_get_model_patterns_from_name,
24-
_load_model_memory_efficient,
25-
auto_generate_target_osft_config,
26-
create_osft_model_class, get_model_config,
27-
is_osft_param, optim_wrapper)
22+
from mini_trainer.osft_utils import (
23+
MODEL_CONFIGS,
24+
_get_model_patterns_from_name,
25+
_load_model_memory_efficient,
26+
auto_generate_target_osft_config,
27+
create_osft_model_class,
28+
get_model_config,
29+
is_osft_param,
30+
optim_wrapper,
31+
)
2832
from mini_trainer.setup_model_for_training import setup_model
2933
from mini_trainer.training_types import TorchrunArgs, TrainingArgs
30-
from tests.test_utils.orthogonality import (OrthogonalityTracker,
31-
check_gradient_orthogonality,
32-
check_parameter_orthogonality,
33-
compute_angle_differences)
34+
from tests.test_utils.orthogonality import (
35+
OrthogonalityTracker,
36+
check_gradient_orthogonality,
37+
check_parameter_orthogonality,
38+
compute_angle_differences,
39+
)
3440

3541

3642
class TestOSFTAPIValidation:

0 commit comments

Comments
 (0)