
Gemma3 GQA Shape Mismatch on Torch Backend with Fused Attention #2603

@pctablet505

Description

When running Gemma3 (which uses Grouped Query Attention) with the torch backend, ops.dot_product_attention fails with a runtime error due to a shape mismatch between query and key/value heads.

This occurs because PyTorch's scaled_dot_product_attention (used by ops.dot_product_attention when fused attention is enabled) does not broadcast key/value heads for Grouped Query Attention (GQA), where the number of KV heads is greater than 1 but smaller than the number of query heads (e.g., 8 KV heads vs. 16 query heads). It only supports Multi-Query Attention (KV heads = 1) or standard Multi-Head Attention (KV heads = query heads).
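For reference, the mismatch can be reproduced with plain PyTorch tensors, along with the usual workaround of repeating each KV head to match the query head count. The head counts (16 query, 8 KV, head_dim 256) come from the issue; the batch and sequence sizes below are illustrative.

```python
import torch
import torch.nn.functional as F

# GQA shapes from the issue: 16 query heads, 8 KV heads, head_dim 256.
q = torch.randn(1, 16, 4, 256)
k = torch.randn(1, 8, 4, 256)
v = torch.randn(1, 8, 4, 256)

# Calling SDPA directly with mismatched head counts fails:
# "The size of tensor a (16) must match the size of tensor b (8) ..."
try:
    F.scaled_dot_product_attention(q, k, v)
except RuntimeError as e:
    print("GQA shape mismatch:", e)

# Workaround: repeat each KV head so KV heads == query heads.
n_rep = q.shape[1] // k.shape[1]  # 16 // 8 == 2
k_rep = k.repeat_interleave(n_rep, dim=1)
v_rep = v.repeat_interleave(n_rep, dim=1)
out = F.scaled_dot_product_attention(q, k_rep, v_rep)
print(out.shape)  # torch.Size([1, 16, 4, 256])
```

Note that recent PyTorch releases (2.5+) add an enable_gqa=True argument to scaled_dot_product_attention that performs this KV-head broadcast internally, so another possible fix is for the Keras torch backend to pass that flag when head counts differ.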

Steps to Reproduce

  1. Use the torch backend.
  2. Load a Gemma3 model that uses GQA (e.g., gemma3_instruct_12b_text where num_query_heads=16 and num_key_value_heads=8).
  3. Run inference (generate).
import os
os.environ["KERAS_BACKEND"] = "torch"
import keras
import keras_hub

# Load a GQA model
model = keras_hub.models.Gemma3CausalLM.from_preset("gemma3_instruct_12b_text")
model.generate("Hello world")

Errors

100%|██████████| 969/969 [00:00<00:00, 3.29MB/s]
Downloading from https://www.kaggle.com/api/v1/models/keras/gemma3/keras/gemma3_instruct_12b_text/3/download/task.json...
100%|██████████| 3.24k/3.24k [00:00<00:00, 10.6MB/s]
Downloading from https://www.kaggle.com/api/v1/models/keras/gemma3/keras/gemma3_instruct_12b_text/3/download/assets/tokenizer/vocabulary.spm...
100%|██████████| 4.47M/4.47M [00:02<00:00, 2.19MB/s]
/usr/local/lib/python3.12/dist-packages/keras/src/backend/torch/core.py:609: UserWarning: Using a non-tuple sequence for multidimensional indexing is deprecated and will be changed in pytorch 2.9; use x[tuple(seq)] instead of x[seq]. In pytorch 2.9 this will be interpreted as tensor index, x[torch.tensor(seq)], which will result either in an error or a different result (Triggered internally at /pytorch/torch/csrc/autograd/python_variable_indexing.cpp:345.)
  outputs[slices] = updates
/usr/local/lib/python3.12/dist-packages/keras/src/backend/torch/core.py:594: UserWarning: Using a non-tuple sequence for multidimensional indexing is deprecated and will be changed in pytorch 2.9; use x[tuple(seq)] instead of x[seq]. In pytorch 2.9 this will be interpreted as tensor index, x[torch.tensor(seq)], which will result either in an error or a different result (Triggered internally at /pytorch/torch/csrc/autograd/python_variable_indexing.cpp:345.)
  return inputs[slices]
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/tmp/ipython-input-1868488927.py in <cell line: 0>()
      6 # Load a GQA model
      7 model = keras_hub.models.Gemma3CausalLM.from_preset("gemma3_instruct_12b_text",load_weights=False)
----> 8 model.generate("Hello world")

14 frames
/usr/local/lib/python3.12/dist-packages/keras_hub/src/models/gemma3/gemma3_causal_lm.py in generate(self, inputs, max_length, stop_token_ids, strip_prompt)
    371             ]
    372 
--> 373         return super().generate(
    374             inputs,
    375             max_length=max_length,

/usr/local/lib/python3.12/dist-packages/keras_hub/src/models/causal_lm.py in generate(self, inputs, max_length, stop_token_ids, strip_prompt)
    398             outputs = [strip_prompt_function(generate(x), x) for x in inputs]
    399         else:
--> 400             outputs = [generate(x) for x in inputs]
    401 
    402         if self.preprocessor is not None:

/usr/local/lib/python3.12/dist-packages/keras_hub/src/models/causal_lm.py in generate(x)
    356 
    357         def generate(x):
--> 358             return generate_function(x, stop_token_ids=stop_token_ids)
    359 
    360         def strip_prompt_function(x, prompt):

/usr/local/lib/python3.12/dist-packages/keras_hub/src/models/causal_lm.py in wrapped_generate_function(inputs, stop_token_ids)
    152             ):
    153                 with torch.no_grad():
--> 154                     return self.generate_step(inputs, stop_token_ids)
    155 
    156             self.generate_function = wrapped_generate_function

/usr/local/lib/python3.12/dist-packages/keras_hub/src/models/gemma3/gemma3_causal_lm.py in generate_step(self, inputs, stop_token_ids)
    281 
    282         # Create and seed cache with a single forward pass.
--> 283         hidden_states, cache = self._build_cache(
    284             token_ids,
    285             img_embeddings,

/usr/local/lib/python3.12/dist-packages/keras_hub/src/models/gemma3/gemma3_causal_lm.py in _build_cache(self, token_ids, img_embeddings, vision_mask, padding_mask, vision_indices)
    216         cache = ops.zeros(shape, dtype=self.compute_dtype)
    217         # Seed the cache.
--> 218         logits, hidden_states, cache = self.call_with_cache(
    219             token_ids=token_ids,
    220             img_embeddings=img_embeddings,

/usr/local/lib/python3.12/dist-packages/keras_hub/src/models/gemma3/gemma3_causal_lm.py in call_with_cache(self, token_ids, cache, cache_update_index, img_embeddings, vision_mask, padding_mask, vision_indices, cache_update_mask)
    182         for i, transformer_layer in enumerate(self.backbone.transformer_layers):
    183             current_cache = cache[:, i, ...]
--> 184             x, next_cache = transformer_layer(
    185                 x,
    186                 cache=current_cache,

/usr/local/lib/python3.12/dist-packages/keras/src/utils/traceback_utils.py in error_handler(*args, **kwargs)
    120             # To get the full stack trace, call:
    121             # `keras.config.disable_traceback_filtering()`
--> 122             raise e.with_traceback(filtered_tb) from None
    123         finally:
    124             del filtered_tb

/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1773             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1774         else:
-> 1775             return self._call_impl(*args, **kwargs)
   1776 
   1777     # torchrec tests the code consistency with the following code

/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1784                 or _global_backward_pre_hooks or _global_backward_hooks
   1785                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1786             return forward_call(*args, **kwargs)
   1787 
   1788         result = None

/usr/local/lib/python3.12/dist-packages/keras_hub/src/models/gemma3/gemma3_decoder_block.py in call(self, x, padding_mask, vision_mask, cache, cache_update_index, cache_update_mask)
    262         )
    263         if cache is not None:
--> 264             attention, new_cache = self.attention(
    265                 normalized_x,
    266                 attention_mask=attention_mask,

/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1773             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1774         else:
-> 1775             return self._call_impl(*args, **kwargs)
   1776 
   1777     # torchrec tests the code consistency with the following code

/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1784                 or _global_backward_pre_hooks or _global_backward_hooks
   1785                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1786             return forward_call(*args, **kwargs)
   1787 
   1788         result = None

/usr/local/lib/python3.12/dist-packages/keras_hub/src/models/gemma3/gemma3_attention.py in call(self, x, attention_mask, cache, cache_update_index, cache_update_mask, training)
    390             value = self.value_dense(x)
    391 
--> 392         attention_vec = self._compute_attention(
    393             query,
    394             key,

/usr/local/lib/python3.12/dist-packages/keras_hub/src/models/gemma3/gemma3_attention.py in _compute_attention(self, q, k, v, attention_mask, training, cache_update_index)
    199             else:
    200                 kwargs = {}
--> 201             return ops.dot_product_attention(
    202                 query=q,
    203                 key=k,

RuntimeError: Exception encountered when calling CachedGemma3Attention.call().

The size of tensor a (16) must match the size of tensor b (8) at non-singleton dimension 1

Arguments received by CachedGemma3Attention.call():
  • x=torch.Tensor(shape=torch.Size([1, 1024, 3840]), dtype=float32)
  • attention_mask=torch.Tensor(shape=torch.Size([1, 1024, 1024]), dtype=int32)
  • cache=torch.Tensor(shape=torch.Size([1, 2, 1024, 8, 256]), dtype=float32)
  • cache_update_index=0
  • cache_update_mask=None
  • training=False

Attached colab:
https://colab.research.google.com/gist/pctablet505/aa35bea5e6d2d7fc00dfb4d49951675f/gemma3-bug.ipynb#scrollTo=K49dC2hOoOSB
