Skip to content

Commit 4ff63d5

Browse files
committed
Add support for ShieldGemma2ForImageClassification
- Register ShieldGemma2ForImageClassification for both TEXT and MMPROJ model types
- Add prefix handling for 'model.' and 'model.language_model.' prefixes in tensor names
- Enable conversion of ShieldGemma models to GGUF format with vision encoder support

This enables conversion of google/shieldgemma-2-4b-it and similar models for content moderation tasks with llama.cpp multimodal support.
1 parent 9b17d74 commit 4ff63d5

File tree

1 file changed

+15
-7
lines changed

1 file changed

+15
-7
lines changed

convert_hf_to_gguf.py

Lines changed: 15 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -5740,7 +5740,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
57405740
return [(self.map_tensor_name(name), data_torch)]
57415741

57425742

5743-
@ModelBase.register("Gemma3ForCausalLM", "Gemma3ForConditionalGeneration")
5743+
@ModelBase.register("Gemma3ForCausalLM", "Gemma3ForConditionalGeneration", "ShieldGemma2ForImageClassification")
57445744
class Gemma3Model(TextModel):
57455745
model_arch = gguf.MODEL_ARCH.GEMMA3
57465746
norm_shift = 1.0 # Gemma3RMSNorm adds 1.0 to the norm value
@@ -5778,13 +5778,17 @@ def set_gguf_parameters(self):
57785778
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
57795779
del bid # unused
57805780

5781-
if "language_model." in name:
5782-
name = name.replace("language_model.", "")
5783-
5784-
elif name.startswith("multi_modal_projector.") or name.startswith("vision_tower.") \
5785-
or name.startswith("multimodal_projector.") or name.startswith("vision_model."):
5781+
# Skip vision tensors (check before prefix removal)
5782+
if "multi_modal_projector." in name or "vision_tower." in name \
5783+
or "multimodal_projector." in name or "vision_model." in name:
57865784
return [] # skip vision tensors
57875785

5786+
# Handle ShieldGemma2 prefix
5787+
if name.startswith("model.language_model."):
5788+
name = name.replace("model.language_model.", "")
5789+
elif "language_model." in name:
5790+
name = name.replace("language_model.", "")
5791+
57885792
# remove OOV (out-of-vocabulary) rows in token_embd
57895793
if "embed_tokens.weight" in name:
57905794
vocab = self._create_vocab_sentencepiece()
@@ -5874,7 +5878,7 @@ def set_gguf_parameters(self):
58745878
self._try_set_pooling_type()
58755879

58765880

5877-
@ModelBase.register("Gemma3ForConditionalGeneration")
5881+
@ModelBase.register("Gemma3ForConditionalGeneration", "ShieldGemma2ForImageClassification")
58785882
class Gemma3VisionModel(MmprojModel):
58795883
def set_gguf_parameters(self):
58805884
super().set_gguf_parameters()
@@ -5908,6 +5912,10 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
59085912
if "vision_model.head." in name:
59095913
return [] # skip redundant tensors for tinygemma3
59105914

5915+
# Handle ShieldGemma2 prefix
5916+
if name.startswith("model."):
5917+
name = name.replace("model.", "", 1)
5918+
59115919
if name.startswith("multi_modal_projector.") or name.startswith("vision_tower.") \
59125920
or name.startswith("multimodal_projector.") or name.startswith("vision_model."):
59135921
# process vision tensors

0 commit comments

Comments
 (0)