Commit 75425ca

Fix Gemma 3 SFT training by detecting dual-registered VLM configs (#695)
Gemma 3 (google/gemma-3-4b-it) is dual-registered in transformers: Gemma3Config maps to both MODEL_FOR_CAUSAL_LM_MAPPING and MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING. This caused is_vlm_with_causal_lm() to return False (because the config IS in the CausalLM mapping), so the model was loaded via AutoModelForCausalLM, which resolves to the full Gemma3ForConditionalGeneration VLM class, not a text-only CausalLM. The VLM forward pass then crashed during FSDP-wrapped distributed training because the text-only SFT training loop doesn't handle the vision tower.

The fix checks what class MODEL_FOR_CAUSAL_LM_MAPPING actually resolves to. If it's a ForConditionalGeneration class (a VLM), the model is treated as needing backbone extraction, same as Ministral/Mistral3 models.

Tested with model_validation.py: both gemma-3-4b-it and gemma-3n-E4B-it now train to loss 0.0000 on a 1000-sample overfit dataset across 8x A100s.

Signed-off-by: Oleg Silkin <97077423+RobotSail@users.noreply.github.com>
1 parent 0c6614a commit 75425ca

File tree

1 file changed: +23 −9 lines


src/instructlab/training/vlm_utils.py

Lines changed: 23 additions & 9 deletions
```diff
@@ -15,13 +15,16 @@
 def is_vlm_with_causal_lm(model_path: str, trust_remote_code: bool = False) -> bool:
     """Check if a model is a VLM that wraps a CausalLM text backbone.
 
-    Returns True when the model's top-level config does NOT map to a CausalLM
-    but its ``text_config`` does — meaning the model needs VLM extraction to
-    obtain the trainable CausalLM sub-model.
+    Returns True when the model needs VLM extraction to obtain the trainable
+    CausalLM sub-model. This covers two cases:
 
-    Models that are dual-registered (top-level config maps directly to CausalLM)
-    return False because ``AutoModelForCausalLM`` can load them without
-    extraction.
+    1. The top-level config does NOT map to CausalLM, but ``text_config`` does
+       (e.g. Ministral-3 / Mistral3ForConditionalGeneration).
+    2. The top-level config IS in the CausalLM mapping, but the resolved class
+       is actually a ``ForConditionalGeneration`` VLM (e.g. Gemma 3, which is
+       dual-registered so ``AutoModelForCausalLM`` loads the full VLM). These
+       models still have an extractable CausalLM text backbone via
+       ``text_config``.
 
     Args:
         model_path: HuggingFace model ID or local path.
@@ -32,11 +35,22 @@ def is_vlm_with_causal_lm(model_path: str, trust_remote_code: bool = False) -> b
     """
     config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)
 
-    # If the top-level config maps to CausalLM, no extraction needed.
+    text_config = getattr(config, "text_config", None)
+
     if config.__class__ in MODEL_FOR_CAUSAL_LM_MAPPING:
-        return False
+        # The config maps to CausalLM, but check what class it actually
+        # resolves to. Some models (e.g. Gemma 3) are dual-registered and
+        # AutoModelForCausalLM loads a ForConditionalGeneration VLM instead
+        # of a text-only CausalLM. Those still need extraction.
+        resolved_cls = MODEL_FOR_CAUSAL_LM_MAPPING[config.__class__]
+        is_actually_vlm = "ForConditionalGeneration" in resolved_cls.__name__
+        if not is_actually_vlm:
+            return False
+        # It's a VLM disguised as CausalLM — fall through to check text_config
+        if text_config is None:
+            return False
+        return text_config.__class__ in MODEL_FOR_CAUSAL_LM_MAPPING
 
-    text_config = getattr(config, "text_config", None)
     if text_config is None:
         return False
     return text_config.__class__ in MODEL_FOR_CAUSAL_LM_MAPPING
```
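The dual-registration check at the heart of this patch can be sketched in isolation. This is a minimal illustration only: the classes and the `CAUSAL_LM_MAPPING` dict below are hypothetical stand-ins for the real transformers config/model classes and `MODEL_FOR_CAUSAL_LM_MAPPING`, and the function name `needs_vlm_extraction` is illustrative, not part of the actual vlm_utils.py API.

```python
# Stand-in classes; in the real code these come from transformers.
class LlamaConfig: ...                        # ordinary text-only model
class LlamaForCausalLM: ...

class Gemma3TextConfig: ...                   # Gemma 3's text backbone config
class Gemma3ForCausalLM: ...

class Gemma3Config:                           # dual-registered VLM config
    def __init__(self):
        self.text_config = Gemma3TextConfig()

class Gemma3ForConditionalGeneration: ...     # the full VLM class

# Stand-in for transformers' MODEL_FOR_CAUSAL_LM_MAPPING: the VLM config
# resolves to a ForConditionalGeneration class (the dual registration).
CAUSAL_LM_MAPPING = {
    LlamaConfig: LlamaForCausalLM,
    Gemma3TextConfig: Gemma3ForCausalLM,
    Gemma3Config: Gemma3ForConditionalGeneration,
}

def needs_vlm_extraction(config) -> bool:
    """Mirror of the patched is_vlm_with_causal_lm() decision logic."""
    text_config = getattr(config, "text_config", None)
    if type(config) in CAUSAL_LM_MAPPING:
        resolved_cls = CAUSAL_LM_MAPPING[type(config)]
        if "ForConditionalGeneration" not in resolved_cls.__name__:
            return False  # genuine text-only CausalLM, load directly
        # Dual-registered VLM: extractable only if a CausalLM backbone exists
        if text_config is None:
            return False
        return type(text_config) in CAUSAL_LM_MAPPING
    # Top-level config not in the CausalLM mapping (Mistral3-style VLMs)
    if text_config is None:
        return False
    return type(text_config) in CAUSAL_LM_MAPPING

print(needs_vlm_extraction(LlamaConfig()))   # plain CausalLM -> False
print(needs_vlm_extraction(Gemma3Config()))  # dual-registered VLM -> True
```

The key design point the patch makes is that membership in MODEL_FOR_CAUSAL_LM_MAPPING alone is no longer a reliable "text-only" signal; the resolved class name must be inspected as well.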
