Commit 75425ca
Fix Gemma 3 SFT training by detecting dual-registered VLM configs (#695)
Gemma 3 (google/gemma-3-4b-it) is dual-registered in transformers:
Gemma3Config maps to both MODEL_FOR_CAUSAL_LM_MAPPING and
MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING. This caused is_vlm_with_causal_lm()
to return False (the config IS in the CausalLM mapping), so the model was
loaded via AutoModelForCausalLM, which resolves to the full
Gemma3ForConditionalGeneration VLM class rather than a text-only CausalLM.
The VLM forward pass then crashed during FSDP-wrapped distributed training
because the text-only SFT training loop doesn't handle the vision tower.
The fix checks which class MODEL_FOR_CAUSAL_LM_MAPPING actually resolves to.
If it is a ForConditionalGeneration class (i.e. a VLM), the model is treated
as needing backbone extraction, the same as the Ministral/Mistral3 models.
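The detection described above can be sketched roughly as follows. This is an illustrative reconstruction, not the actual patch: the function name `is_vlm_despite_causal_lm_registration` and the mock mapping are hypothetical stand-ins for the project's helper and for transformers' real MODEL_FOR_CAUSAL_LM_MAPPING.

```python
# Hypothetical sketch of the dual-registration check described above.
# The real code consults transformers' MODEL_FOR_CAUSAL_LM_MAPPING; here a
# plain dict stands in so the example is self-contained.

def is_vlm_despite_causal_lm_registration(config_cls, causal_lm_mapping):
    """Return True when the CausalLM mapping resolves the config to a
    ForConditionalGeneration class, i.e. a VLM whose text backbone must
    be extracted before text-only SFT."""
    model_cls = causal_lm_mapping.get(config_cls)
    if model_cls is None:
        return False
    return model_cls.__name__.endswith("ForConditionalGeneration")


# Illustrative stand-ins for the real transformers classes:
class Gemma3Config: ...
class Gemma3ForConditionalGeneration: ...
class LlamaConfig: ...
class LlamaForCausalLM: ...

MOCK_CAUSAL_LM_MAPPING = {
    Gemma3Config: Gemma3ForConditionalGeneration,  # dual-registered VLM
    LlamaConfig: LlamaForCausalLM,                 # plain text-only model
}
```

With this check, a dual-registered config like Gemma3Config is routed to backbone extraction even though it appears in the CausalLM mapping.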
Tested with model_validation.py: both gemma-3-4b-it and gemma-3n-E4B-it
now train to loss 0.0000 on a 1000-sample overfit dataset across 8x A100s.
Signed-off-by: Oleg Silkin <97077423+RobotSail@users.noreply.github.com>
1 file changed: +23, -9