diff --git a/keras_hub/src/models/gemma3/gemma3_causal_lm.py b/keras_hub/src/models/gemma3/gemma3_causal_lm.py index 8fa2811598..c07a4bdbc6 100644 --- a/keras_hub/src/models/gemma3/gemma3_causal_lm.py +++ b/keras_hub/src/models/gemma3/gemma3_causal_lm.py @@ -227,7 +227,7 @@ def _build_cache( ) return hidden_states, cache - def generate_step(self, inputs, stop_token_ids=[106]): + def generate_step(self, inputs, stop_token_ids=None): """A compilable generation function for a single batch of inputs. This function represents the inner, XLA-compilable, generation function @@ -326,11 +326,14 @@ def next(prompt, cache, index): else: # Without early stopping, all locations will have been updated. padding_mask = ops.ones_like(token_ids, dtype="bool") - return { + output_dict = { "token_ids": token_ids, "padding_mask": padding_mask, - "images": images, } + if images is not None: + output_dict["images"] = images + + return output_dict def generate( self,