-
Notifications
You must be signed in to change notification settings - Fork 331
Description
When running Gemma3 (which uses Grouped Query Attention) with the torch backend, ops.dot_product_attention fails with a runtime error due to a shape mismatch between query and key/value heads.
This occurs because PyTorch's scaled_dot_product_attention (used by ops.dot_product_attention when the fused attention path is enabled) does not broadcast key/value heads for Grouped Query Attention (GQA), i.e. when the number of KV heads is greater than 1 but smaller than the number of query heads (e.g., 8 KV heads vs 16 query heads). It only supports Multi-Query Attention (KV heads = 1) or standard Multi-Head Attention (KV heads = query heads).
Steps to Reproduce
- Use the torch backend.
- Load a Gemma3 model that uses GQA (e.g., gemma3_instruct_12b_text where num_query_heads=16 and num_key_value_heads=8).
- Run inference (generate).
import os
os.environ["KERAS_BACKEND"] = "torch"
import keras_hub
import keras
# Load a GQA model
model = keras_hub.models.Gemma3CausalLM.from_preset("gemma3_instruct_12b_text")
model.generate("Hello world")
Errors
100%|██████████| 969/969 [00:00<00:00, 3.29MB/s]
Downloading from https://www.kaggle.com/api/v1/models/keras/gemma3/keras/gemma3_instruct_12b_text/3/download/task.json...
100%|██████████| 3.24k/3.24k [00:00<00:00, 10.6MB/s]
Downloading from https://www.kaggle.com/api/v1/models/keras/gemma3/keras/gemma3_instruct_12b_text/3/download/assets/tokenizer/vocabulary.spm...
100%|██████████| 4.47M/4.47M [00:02<00:00, 2.19MB/s]
/usr/local/lib/python3.12/dist-packages/keras/src/backend/torch/core.py:609: UserWarning: Using a non-tuple sequence for multidimensional indexing is deprecated and will be changed in pytorch 2.9; use x[tuple(seq)] instead of x[seq]. In pytorch 2.9 this will be interpreted as tensor index, x[torch.tensor(seq)], which will result either in an error or a different result (Triggered internally at /pytorch/torch/csrc/autograd/python_variable_indexing.cpp:345.)
outputs[slices] = updates
/usr/local/lib/python3.12/dist-packages/keras/src/backend/torch/core.py:594: UserWarning: Using a non-tuple sequence for multidimensional indexing is deprecated and will be changed in pytorch 2.9; use x[tuple(seq)] instead of x[seq]. In pytorch 2.9 this will be interpreted as tensor index, x[torch.tensor(seq)], which will result either in an error or a different result (Triggered internally at /pytorch/torch/csrc/autograd/python_variable_indexing.cpp:345.)
return inputs[slices]
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
/tmp/ipython-input-1868488927.py in <cell line: 0>()
6 # Load a GQA model
7 model = keras_hub.models.Gemma3CausalLM.from_preset("gemma3_instruct_12b_text",load_weights=False)
----> 8 model.generate("Hello world")
14 frames
/usr/local/lib/python3.12/dist-packages/keras_hub/src/models/gemma3/gemma3_causal_lm.py in generate(self, inputs, max_length, stop_token_ids, strip_prompt)
371 ]
372
--> 373 return super().generate(
374 inputs,
375 max_length=max_length,
/usr/local/lib/python3.12/dist-packages/keras_hub/src/models/causal_lm.py in generate(self, inputs, max_length, stop_token_ids, strip_prompt)
398 outputs = [strip_prompt_function(generate(x), x) for x in inputs]
399 else:
--> 400 outputs = [generate(x) for x in inputs]
401
402 if self.preprocessor is not None:
/usr/local/lib/python3.12/dist-packages/keras_hub/src/models/causal_lm.py in generate(x)
356
357 def generate(x):
--> 358 return generate_function(x, stop_token_ids=stop_token_ids)
359
360 def strip_prompt_function(x, prompt):
/usr/local/lib/python3.12/dist-packages/keras_hub/src/models/causal_lm.py in wrapped_generate_function(inputs, stop_token_ids)
152 ):
153 with torch.no_grad():
--> 154 return self.generate_step(inputs, stop_token_ids)
155
156 self.generate_function = wrapped_generate_function
/usr/local/lib/python3.12/dist-packages/keras_hub/src/models/gemma3/gemma3_causal_lm.py in generate_step(self, inputs, stop_token_ids)
281
282 # Create and seed cache with a single forward pass.
--> 283 hidden_states, cache = self._build_cache(
284 token_ids,
285 img_embeddings,
/usr/local/lib/python3.12/dist-packages/keras_hub/src/models/gemma3/gemma3_causal_lm.py in _build_cache(self, token_ids, img_embeddings, vision_mask, padding_mask, vision_indices)
216 cache = ops.zeros(shape, dtype=self.compute_dtype)
217 # Seed the cache.
--> 218 logits, hidden_states, cache = self.call_with_cache(
219 token_ids=token_ids,
220 img_embeddings=img_embeddings,
/usr/local/lib/python3.12/dist-packages/keras_hub/src/models/gemma3/gemma3_causal_lm.py in call_with_cache(self, token_ids, cache, cache_update_index, img_embeddings, vision_mask, padding_mask, vision_indices, cache_update_mask)
182 for i, transformer_layer in enumerate(self.backbone.transformer_layers):
183 current_cache = cache[:, i, ...]
--> 184 x, next_cache = transformer_layer(
185 x,
186 cache=current_cache,
/usr/local/lib/python3.12/dist-packages/keras/src/utils/traceback_utils.py in error_handler(*args, **kwargs)
120 # To get the full stack trace, call:
121 # `keras.config.disable_traceback_filtering()`
--> 122 raise e.with_traceback(filtered_tb) from None
123 finally:
124 del filtered_tb
/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
1773 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1774 else:
-> 1775 return self._call_impl(*args, **kwargs)
1776
1777 # torchrec tests the code consistency with the following code
/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1784 or _global_backward_pre_hooks or _global_backward_hooks
1785 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1786 return forward_call(*args, **kwargs)
1787
1788 result = None
/usr/local/lib/python3.12/dist-packages/keras_hub/src/models/gemma3/gemma3_decoder_block.py in call(self, x, padding_mask, vision_mask, cache, cache_update_index, cache_update_mask)
262 )
263 if cache is not None:
--> 264 attention, new_cache = self.attention(
265 normalized_x,
266 attention_mask=attention_mask,
/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
1773 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1774 else:
-> 1775 return self._call_impl(*args, **kwargs)
1776
1777 # torchrec tests the code consistency with the following code
/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1784 or _global_backward_pre_hooks or _global_backward_hooks
1785 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1786 return forward_call(*args, **kwargs)
1787
1788 result = None
/usr/local/lib/python3.12/dist-packages/keras_hub/src/models/gemma3/gemma3_attention.py in call(self, x, attention_mask, cache, cache_update_index, cache_update_mask, training)
390 value = self.value_dense(x)
391
--> 392 attention_vec = self._compute_attention(
393 query,
394 key,
/usr/local/lib/python3.12/dist-packages/keras_hub/src/models/gemma3/gemma3_attention.py in _compute_attention(self, q, k, v, attention_mask, training, cache_update_index)
199 else:
200 kwargs = {}
--> 201 return ops.dot_product_attention(
202 query=q,
203 key=k,
RuntimeError: Exception encountered when calling CachedGemma3Attention.call().
The size of tensor a (16) must match the size of tensor b (8) at non-singleton dimension 1
Arguments received by CachedGemma3Attention.call():
• x=torch.Tensor(shape=torch.Size([1, 1024, 3840]), dtype=float32)
• attention_mask=torch.Tensor(shape=torch.Size([1, 1024, 1024]), dtype=int32)
• cache=torch.Tensor(shape=torch.Size([1, 2, 1024, 8, 256]), dtype=float32)
• cache_update_index=0
• cache_update_mask=None
• training=False
Attached colab:
https://colab.research.google.com/gist/pctablet505/aa35bea5e6d2d7fc00dfb4d49951675f/gemma3-bug.ipynb#scrollTo=K49dC2hOoOSB