
Commit bee5dca

Fix accessing final norm for Gemma-3 models (microsoft#1687)
### Description

This PR fixes how the final norm is identified for the Gemma-3 models. It works with the latest version of Hugging Face's `transformers` (v4.55.2).

### Motivation and Context

Previous versions of `transformers` made breaking changes to the class structure of the Gemma-3 models. Since `transformers` has [landed on a stable way](huggingface/transformers#36741) to load multi-modal models with `AutoModelForCausalLM` for now, the current approach is to identify the path to `model.model.language_model.norm` for the Gemma-3 models that are multi-modal. Gemma-3 1B's final norm is accessible at `model.model.norm`, while Gemma-3 4B's final norm is accessible at `model.model.language_model.norm`.

For [PEFT's](https://github.com/huggingface/peft) decoder-only models, the core model is accessible at `model.base_model.model` and the final norm is usually accessible at `model.base_model.model.model.norm`. We can read the parent-most class name to identify whether a model is from PEFT or not. One advantage of this approach is that any adaptation in the path to the final norm of a Transformers model can still be found in the PEFT version of that model.
1 parent 47fb158 commit bee5dca
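
For illustration, here is a minimal sketch of where the final norm lives in each case described above. It is not part of the commit; it assumes `transformers` v4.55.2 and `peft` are installed, and the model IDs and LoRA settings are placeholder choices.

```python
# Minimal sketch of the final-norm paths described above (illustrative only).
# Assumes transformers v4.55.2 and peft are installed; the model IDs and LoRA
# settings below are placeholder choices, not part of the commit.
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# Gemma-3 1B (text-only): final norm at model.model.norm
text_model = AutoModelForCausalLM.from_pretrained("google/gemma-3-1b-it")
print(type(text_model.model.norm))                          # Gemma3RMSNorm

# Gemma-3 4B (multi-modal): final norm at model.model.language_model.norm
mm_model = AutoModelForCausalLM.from_pretrained("google/gemma-3-4b-it")
print(type(mm_model.model.language_model.norm))             # Gemma3RMSNorm

# PEFT decoder-only model: the parent-most class name starts with "Peft",
# the core model sits at model.base_model.model, and the final norm is one
# level deeper at model.base_model.model.model.norm.
peft_model = get_peft_model(
    text_model, LoraConfig(task_type="CAUSAL_LM", target_modules=["q_proj", "v_proj"])
)
print(peft_model.__class__.__name__)                        # PeftModelForCausalLM
print(type(peft_model.base_model.model.model.norm))         # Gemma3RMSNorm
```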

File tree: 1 file changed (+14, −18 lines)


src/python/py/models/builder.py

Lines changed: 14 additions & 18 deletions
```diff
@@ -2713,31 +2713,27 @@ def make_model(self, input_path):
     def has_final_norm(self, module, orig_model):
         # Find where the language model is stored to check attributes. Some classes
         # store the language model in a different attribute than `model.model`.
-        if hasattr(orig_model, "language_model"):
-            # Model is multimodal
-            # Note: This case is checked first because the `language_model` attribute and the `base_model` attribute
-            # exist for both multimodal models and PEFT models. However they represent different classes and their attributes
-            # differ.
-            model = orig_model.language_model
-        elif hasattr(orig_model, "base_model") and hasattr(orig_model.base_model, "model"):
-            if hasattr(orig_model.base_model.model, "model"):
-                # Model is from PEFT
-                model = orig_model.base_model.model
-            else:
-                # Model is text-based only.
-                model = orig_model.base_model
+        if orig_model.__class__.__name__.startswith("Peft"):
+            # Model is from PEFT
+            model = orig_model.base_model.model
         else:
             model = orig_model
 
-        # Hugging Face names
+        # Hugging Face names (all models loaded with AutoModelForCausalLM.from_pretrained)
+        #
+        # hf_norm: for most models
+        # hf_final_layernorm: for Phi-2
+        # hf_transformer_final_layernorm: for ChatGLM-3
+        # hf_language_model_norm: for Gemma-3 multimodal (4B, 12B, 27B)
         hf_norm = hasattr(model, "model") and hasattr(model.model, "norm") and module == model.model.norm
         hf_final_layernorm = hasattr(model, "model") and hasattr(model.model, "final_layernorm") and module == model.model.final_layernorm
         hf_transformer_final_layernorm = hasattr(model, "transformer") and hasattr(model.transformer, "encoder") and hasattr(model.transformer.encoder, "final_layernorm") and module == model.transformer.encoder.final_layernorm
+        hf_language_model_norm = hasattr(model, "model") and hasattr(model.model, "language_model") and hasattr(model.model.language_model, "norm") and module == model.model.language_model.norm
 
-        # GGUF names
+        # GGUF names (all models loaded with GGUFModel.from_pretrained)
         gguf_final_norm = hasattr(model, "final_norm") and module == model.final_norm
 
-        hf_names = [hf_norm, hf_final_layernorm, hf_transformer_final_layernorm]
+        hf_names = [hf_norm, hf_final_layernorm, hf_transformer_final_layernorm, hf_language_model_norm]
         gguf_names = [gguf_final_norm]
         return any(hf_names + gguf_names)
 
```
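As a rough illustration of what the updated checks resolve to, a hypothetical driver that walks a loaded model's modules might look like the sketch below; `builder` stands in for the `Model` instance from builder.py, and the actual call sites in builder.py differ.

```python
# Hypothetical driver for has_final_norm(); `builder` stands in for the
# builder.py Model instance, and the real call sites differ from this sketch.
def find_final_norm(builder, hf_model):
    for name, module in hf_model.named_modules():
        if builder.has_final_norm(module, hf_model):
            return name
    return None

# For multi-modal Gemma-3 loaded with AutoModelForCausalLM this would return
# "model.language_model.norm"; for Gemma-3 1B, "model.norm".
```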

```diff
@@ -3264,7 +3260,7 @@ def make_layernorm(self, layer_id, layernorm, skip, simple, location):
         super().make_layernorm(layer_id, layernorm, skip, simple, location)
 
     def make_layer(self, layer_id, layer):
-        # Gemma2 decoder layer is typically defined as:
+        # Gemma-2 decoder layer is typically defined as:
         # input_layernorm --> attention --> post_attention_layernorm --> pre_ffn_layernorm --> MLP --> post_ffn_layernorm
 
         # Adjust LayerNorm attributes because of extra LayerNorms inserted
```
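Read as a pseudo-forward pass, that ordering looks roughly like the sketch below; the attribute names follow the comment rather than the exact Hugging Face names, and the residual placement is an assumption based on Gemma-2's sandwich norms.

```python
# Pseudo-forward matching the layer ordering in the comment above; attribute
# names follow the comment, not necessarily the Hugging Face implementation.
def gemma2_decoder_layer(layer, hidden_states):
    residual = hidden_states
    hidden_states = layer.input_layernorm(hidden_states)
    hidden_states = layer.attention(hidden_states)
    hidden_states = residual + layer.post_attention_layernorm(hidden_states)

    residual = hidden_states
    hidden_states = layer.pre_ffn_layernorm(hidden_states)
    hidden_states = layer.mlp(hidden_states)
    hidden_states = residual + layer.post_ffn_layernorm(hidden_states)
    return hidden_states
```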
```diff
@@ -3713,7 +3709,7 @@ def make_layer(self, layer_id, layer):
 class Gemma3Model(Gemma2Model):
     def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
         super().__init__(config, io_dtype, onnx_dtype, ep, cache_dir, extra_options)
-        self.is_local = lambda layer_id: bool((layer_id + 1) % config.sliding_window_pattern)
+        self.is_local = lambda layer_id: bool((layer_id + 1) % 6)
         self.rope_local_theta = config.rope_local_base_freq
         self.make_rotary_embedding_multi_cache()
 
```
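With the hard-coded modulus, every sixth Gemma-3 layer uses global attention and the rest use local (sliding-window) attention; a quick worked example of the resulting pattern:

```python
# Layer pattern implied by bool((layer_id + 1) % 6): every 6th layer is global.
is_local = lambda layer_id: bool((layer_id + 1) % 6)

print(["local" if is_local(i) else "global" for i in range(12)])
# ['local', 'local', 'local', 'local', 'local', 'global',
#  'local', 'local', 'local', 'local', 'local', 'global']
```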
