Skip to content

Commit 4a41006

Browse files
Gemma 3 fix and patch release. (#2520)
* Fix overflow issue in Gemma3 float16 (#2519)
* fix gemma3 decoder block overflow issue
* Fix overflow issue in float16
* code reformat
* patch release

---------

Co-authored-by: Divyashree Sreepathihalli <divyashreepathihalli@gmail.com>
1 parent f1da303 commit 4a41006

File tree

2 files changed

+23
-3
lines changed

2 files changed

+23
-3
lines changed

keras_hub/src/models/gemma3/gemma3_decoder_block.py

Lines changed: 22 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -251,6 +251,11 @@ def call(
251251
cache_update_mask=None,
252252
):
253253
# Note: `vision_mask` is used only for Gemma3.
254+
# If float16, we clamp the input to avoid overflow.
255+
is_float16 = keras.backend.standardize_dtype(x.dtype) == "float16"
256+
if is_float16:
257+
x = ops.clip(x, -65504, 65504)
258+
254259
normalized_x = self.pre_attention_norm(x)
255260
attention_mask = self._compute_attention_mask(
256261
normalized_x, padding_mask, vision_mask, cache, cache_update_index
@@ -275,7 +280,15 @@ def call(
275280
if self.dropout:
276281
attention = self.attention_dropout(attention)
277282

278-
attention_x = x + attention
283+
if is_float16:
284+
attention_x = ops.add(
285+
ops.cast(x, "float32"), ops.cast(attention, "float32")
286+
)
287+
attention_x = ops.clip(attention_x, -65504, 65504)
288+
attention_x = ops.cast(attention_x, "float16")
289+
else:
290+
attention_x = x + attention
291+
279292
normalized_x = self.pre_ffw_norm(attention_x)
280293

281294
x1 = self.gating_ffw(normalized_x)
@@ -286,7 +299,14 @@ def call(
286299
if self.use_post_ffw_norm:
287300
x = self.post_ffw_norm(x)
288301

289-
x = x + attention_x
302+
if is_float16:
303+
x = ops.add(
304+
ops.cast(x, "float32"), ops.cast(attention_x, "float32")
305+
)
306+
x = ops.clip(x, -65504, 65504)
307+
x = ops.cast(x, "float16")
308+
else:
309+
x = x + attention_x
290310

291311
if cache is not None:
292312
return x, new_cache

keras_hub/src/version.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
from keras_hub.src.api_export import keras_hub_export
22

33
# Unique source of truth for the version number.
4-
__version__ = "0.25.0"
4+
__version__ = "0.25.1"
55

66

77
@keras_hub_export("keras_hub.version")

0 commit comments

Comments
 (0)