Skip to content

Commit fef3bc3

Browse files
Add T5Gemma to KerasHub (#2339)
* init: Add initial project structure and files * nit: Fix code format test; and cool AI-generated reviews * refactor: Cleanup and replace incorrect T5LayerNorm with RMSNormalization (Gemma) * fix: Numerics @ atol=1e-4 * refactor: Refactor T5Gemma decoder cache handling * feat: Add checkpoint conversion script * nit: Precise compute_output_shape methods; document head_dim * nit: Propagate dtypes * bug fix + minor cleanup: Fix head_dim default → head_dim from config * perf(jax/tpu): Fused kernel optim for TPU backend + get_config() args * cleanup: Slight refactor * fix: Enable mixed precision and quantization tests * feat: Add support for asymmetrical presets (only invariants included) * refactor: Address reviews - presets will be handled post D-FINE * feat: Support direct loading of Hugging Face checkpoints * ✅ Yayy: Generate outputs identical, hidden states match within 1e-3 * preset test: Register and test a preset (to be replaced later by the team with the full set) * nit: Sharded weights don’t include `model.weights.h5` * nits: Address reviews + replace gated model
1 parent 61fda7f commit fef3bc3

18 files changed

+3202
-0
lines changed

keras_hub/api/models/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -615,6 +615,18 @@
615615
T5Preprocessor as T5Preprocessor,
616616
)
617617
from keras_hub.src.models.t5.t5_tokenizer import T5Tokenizer as T5Tokenizer
618+
from keras_hub.src.models.t5gemma.t5gemma_backbone import (
619+
T5GemmaBackbone as T5GemmaBackbone,
620+
)
621+
from keras_hub.src.models.t5gemma.t5gemma_seq_2_seq_lm import (
622+
T5GemmaSeq2SeqLM as T5GemmaSeq2SeqLM,
623+
)
624+
from keras_hub.src.models.t5gemma.t5gemma_seq_2_seq_lm_preprocessor import (
625+
T5GemmaSeq2SeqLMPreprocessor as T5GemmaSeq2SeqLMPreprocessor,
626+
)
627+
from keras_hub.src.models.t5gemma.t5gemma_tokenizer import (
628+
T5GemmaTokenizer as T5GemmaTokenizer,
629+
)
618630
from keras_hub.src.models.task import Task as Task
619631
from keras_hub.src.models.text_classifier import TextClassifier as Classifier
620632
from keras_hub.src.models.text_classifier import (

keras_hub/api/tokenizers/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,9 @@
8888
SigLIPTokenizer as SigLIPTokenizer,
8989
)
9090
from keras_hub.src.models.t5.t5_tokenizer import T5Tokenizer as T5Tokenizer
91+
from keras_hub.src.models.t5gemma.t5gemma_tokenizer import (
92+
T5GemmaTokenizer as T5GemmaTokenizer,
93+
)
9194
from keras_hub.src.models.whisper.whisper_tokenizer import (
9295
WhisperTokenizer as WhisperTokenizer,
9396
)
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from keras_hub.src.models.t5gemma.t5gemma_backbone import T5GemmaBackbone
2+
from keras_hub.src.models.t5gemma.t5gemma_presets import backbone_presets
3+
from keras_hub.src.utils.preset_utils import register_presets
4+
5+
register_presets(backbone_presets, T5GemmaBackbone)

0 commit comments

Comments
 (0)