diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py
index 0a71dbcace..6425d4c0f0 100644
--- a/keras_hub/api/models/__init__.py
+++ b/keras_hub/api/models/__init__.py
@@ -421,6 +421,16 @@
from keras_hub.src.models.phi3.phi3_tokenizer import (
Phi3Tokenizer as Phi3Tokenizer,
)
+from keras_hub.src.models.phi4.phi4_backbone import Phi4Backbone as Phi4Backbone
+from keras_hub.src.models.phi4.phi4_causal_lm import (
+ Phi4CausalLM as Phi4CausalLM,
+)
+from keras_hub.src.models.phi4.phi4_causal_lm_preprocessor import (
+ Phi4CausalLMPreprocessor as Phi4CausalLMPreprocessor,
+)
+from keras_hub.src.models.phi4.phi4_tokenizer import (
+ Phi4Tokenizer as Phi4Tokenizer,
+)
from keras_hub.src.models.preprocessor import Preprocessor as Preprocessor
from keras_hub.src.models.qwen.qwen_backbone import (
QwenBackbone as Qwen2Backbone,
diff --git a/keras_hub/api/tokenizers/__init__.py b/keras_hub/api/tokenizers/__init__.py
index 082078184f..822840ee61 100644
--- a/keras_hub/api/tokenizers/__init__.py
+++ b/keras_hub/api/tokenizers/__init__.py
@@ -68,6 +68,9 @@
from keras_hub.src.models.phi3.phi3_tokenizer import (
Phi3Tokenizer as Phi3Tokenizer,
)
+from keras_hub.src.models.phi4.phi4_tokenizer import (
+ Phi4Tokenizer as Phi4Tokenizer,
+)
from keras_hub.src.models.qwen.qwen_tokenizer import (
QwenTokenizer as Qwen2Tokenizer,
)
diff --git a/keras_hub/src/models/phi4/__init__.py b/keras_hub/src/models/phi4/__init__.py
new file mode 100644
index 0000000000..c03faf3c17
--- /dev/null
+++ b/keras_hub/src/models/phi4/__init__.py
@@ -0,0 +1 @@
+# TODO: Add a register_presets call once phi4_presets.py is implemented.
diff --git a/keras_hub/src/models/phi4/phi4_backbone.py b/keras_hub/src/models/phi4/phi4_backbone.py
new file mode 100644
index 0000000000..23328c3f78
--- /dev/null
+++ b/keras_hub/src/models/phi4/phi4_backbone.py
@@ -0,0 +1,87 @@
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.phi3.phi3_backbone import Phi3Backbone
+
+
+@keras_hub_export("keras_hub.models.Phi4Backbone")
+class Phi4Backbone(Phi3Backbone):
+ """Phi-4 core network with hyperparameters.
+
+ This network implements a Transformer-based decoder network,
+ Phi-4, as described in ["Phi-4 Technical Report"](https://arxiv.org/pdf/2412.08905).
+ It includes the embedding lookups and transformer layers.
+
+ The default constructor gives a fully customizable, randomly initialized
+ phi-4 model with any number of layers, heads, and embedding
+ dimensions. To load preset architectures and weights, use the `from_preset`
+ constructor.
+
+    Note that the defaults here are the Phi-3 defaults, because the Phi-4
+    model follows the Phi-3-medium architecture with different hyperparameters.
+    Use `keras_hub.models.Backbone.from_preset` to get the Phi-4 defaults.
+
+ Args:
+ vocabulary_size: int. The size of the token vocabulary.
+ num_layers: int. The number of transformer layers.
+ hidden_dim: int. The size of the embeddings and the hidden states of
+ the transformer layers.
+ intermediate_dim: int. The output dimension of the first Dense layer in
+ a three-layer feedforward network for each transformer.
+ num_query_heads: int. The number of query attention heads for each
+ transformer layer.
+ num_key_value_heads: int. The number of key and value attention heads
+ for each transformer layer.
+ layer_norm_epsilon: float, optional. Epsilon for the RMS layernorm
+ layers in the transformer decoder. Defaults to `1e-6`.
+        dropout: float, optional. Dropout probability for the Transformer
+            decoder.
+ max_sequence_length: int, optional. The maximum sequence length
+ that this model might ever be used with. Defaults to `4096`.
+ pretraining_sequence_length: int, optional. The maximum sequence length
+ that the model was pretrained with. Defaults to `4096`.
+ rope_max_wavelength: int, optional. The maximum angular wavelength of
+ the sine/cosine curves, for rotary embeddings. Defaults to `10000`.
+        rope_scaling_type: str, optional. The type of the rope scaling. Can
+            be either `None` or `"su"`. `None` is for no rope scaling, and
+            `"su"` is for SuScaled rope, used when `max_sequence_length` is
+            larger than `pretraining_sequence_length`. Defaults to `None`.
+ rope_scaling_short_factor: list[float]. List of factors used to adjust
+ rope frequencies when the `rope_scaling_type` is `"su"`. List must
+ be of length `hidden_dim//num_query_heads//2`. It is used when
+ `sequence_length` is smaller than `pretraining_sequence_length`.
+ Defaults to `None`.
+ rope_scaling_long_factor: list[float]. List of factors used to adjust
+ rope frequencies when the `rope_scaling_type` is `"su"`. List must
+ be of length `hidden_dim//num_query_heads//2`. It is used when
+ `sequence_length` is larger than `pretraining_sequence_length`.
+ Defaults to `None`.
+ dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
+ for model computations and weights. Note that some computations,
+ such as softmax and layer normalization, will always be done at
+ float32 precision regardless of dtype.
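+
+    Example usage (a minimal sketch; the preset call assumes the Phi-4
+    presets have been registered, see the TODO in the package `__init__.py`):
+    ```python
+    input_data = {
+        "token_ids": np.ones(shape=(1, 12), dtype="int32"),
+        "padding_mask": np.array([[1] * 7 + [0] * 5]),
+    }
+
+    # Pretrained Phi-4 decoder.
+    model = keras_hub.models.Phi4Backbone.from_preset(
+        "phi4_mini_4k_instruct_en"
+    )
+    model(input_data)
+
+    # Randomly initialized Phi-4 decoder with a custom config.
+    model = keras_hub.models.Phi4Backbone(
+        vocabulary_size=10,
+        num_layers=2,
+        num_query_heads=4,
+        num_key_value_heads=2,
+        hidden_dim=8,
+        intermediate_dim=16,
+    )
+    model(input_data)
+    ```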
+ """
diff --git a/keras_hub/src/models/phi4/phi4_backbone_test.py b/keras_hub/src/models/phi4/phi4_backbone_test.py
new file mode 100644
index 0000000000..adac55054e
--- /dev/null
+++ b/keras_hub/src/models/phi4/phi4_backbone_test.py
@@ -0,0 +1,92 @@
+import pytest
+from keras import ops
+
+from keras_hub.src.models.phi4.phi4_backbone import Phi4Backbone
+from keras_hub.src.tests.test_case import TestCase
+
+
+class Phi4Test(TestCase):
+ def setUp(self):
+ self.init_kwargs = {
+ "vocabulary_size": 10,
+ "num_layers": 2,
+ "num_query_heads": 4,
+ "num_key_value_heads": 2,
+ "hidden_dim": 8,
+ "intermediate_dim": 8,
+ }
+ self.su_rotary_init_kwargs = {
+ "vocabulary_size": 10,
+ "num_layers": 2,
+ "num_query_heads": 2,
+ "num_key_value_heads": 1,
+ "hidden_dim": 8,
+ "intermediate_dim": 12,
+ "max_sequence_length": 10,
+ "pretraining_sequence_length": 5,
+ "rope_scaling_type": "su",
+ "rope_scaling_short_factor": [1.2, 1.4],
+ "rope_scaling_long_factor": [0.8, 0.6],
+ }
+ self.input_data = {
+ "token_ids": ops.ones((2, 5), dtype="int32"),
+ "padding_mask": ops.ones((2, 5), dtype="int32"),
+ }
+
+ def test_backbone_basics(self):
+ self.run_backbone_test(
+ cls=Phi4Backbone,
+ init_kwargs=self.init_kwargs,
+ input_data=self.input_data,
+ expected_output_shape=(2, 5, 8),
+ )
+
+ @pytest.mark.large
+ def test_saved_model(self):
+ self.run_model_saving_test(
+ cls=Phi4Backbone,
+ init_kwargs=self.init_kwargs,
+ input_data=self.input_data,
+ )
+
+ def test_backbone_basics_with_su_rotary(self):
+ self.run_backbone_test(
+ cls=Phi4Backbone,
+ init_kwargs=self.su_rotary_init_kwargs,
+ input_data=self.input_data,
+ expected_output_shape=(2, 5, 8),
+ )
+
+ @pytest.mark.large
+ def test_saved_model_with_su_rotary(self):
+ self.run_model_saving_test(
+ cls=Phi4Backbone,
+ init_kwargs=self.su_rotary_init_kwargs,
+ input_data=self.input_data,
+ )
+
+ @pytest.mark.extra_large
+ def test_smallest_preset(self):
+ self.run_preset_test(
+ cls=Phi4Backbone,
+ preset="phi4_mini_4k_instruct_en",
+ input_data={
+ "token_ids": ops.array([[1, 450, 4996, 1701, 29916, 29889]]),
+ "padding_mask": ops.ones((1, 6), dtype="int32"),
+ },
+ expected_output_shape=(1, 6, 3072),
+ # The forward pass from a preset should be stable!
+ # Reference values computed using PyTorch HF model.
+ expected_partial_output=ops.array(
+ [-0.21222, 0.04004, -0.02759, 0.02200]
+ ),
+ )
+
+ @pytest.mark.extra_large
+ def test_all_presets(self):
+ for preset in Phi4Backbone.presets:
+ self.run_preset_test(
+ cls=Phi4Backbone,
+ preset=preset,
+ input_data=self.input_data,
+ )
diff --git a/keras_hub/src/models/phi4/phi4_causal_lm.py b/keras_hub/src/models/phi4/phi4_causal_lm.py
new file mode 100644
index 0000000000..8f014d81e2
--- /dev/null
+++ b/keras_hub/src/models/phi4/phi4_causal_lm.py
@@ -0,0 +1,46 @@
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.phi3.phi3_causal_lm import Phi3CausalLM
+from keras_hub.src.models.phi4.phi4_backbone import Phi4Backbone
+from keras_hub.src.models.phi4.phi4_causal_lm_preprocessor import (
+ Phi4CausalLMPreprocessor,
+)
+
+
+@keras_hub_export("keras_hub.models.Phi4CausalLM")
+class Phi4CausalLM(Phi3CausalLM):
+ """An end-to-end Phi4 model for causal language modeling.
+
+ A causal language model (LM) predicts the next token based on previous
+ tokens. This task setup can be used to train the model unsupervised on
+ plain text input, or to autoregressively generate plain text similar to
+ the data used for training. This task can be used for pre-training or
+ fine-tuning a Phi-4 model, simply by calling `fit()`.
+
+ This model has a `generate()` method, which generates text based on a
+ prompt. The generation strategy used is controlled by an additional
+ `sampler` argument on `compile()`. You can recompile the model with
+ different `keras_hub.samplers` objects to control the generation. By
+ default, `"top_k"` sampling will be used.
+
+ Args:
+ backbone: A `keras_hub.models.Phi4Backbone` instance.
+ preprocessor: A `keras_hub.models.Phi4CausalLMPreprocessor` or `None`.
+ If `None`, this model will not apply preprocessing, and inputs
+ should be preprocessed before calling the model.
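+
+    Example usage (a minimal sketch; the preset name mirrors this PR's
+    tests and assumes the Phi-4 presets have been registered):
+    ```python
+    causal_lm = keras_hub.models.Phi4CausalLM.from_preset(
+        "phi4_mini_4k_instruct_en"
+    )
+    causal_lm.generate("What is Keras?", max_length=64)
+
+    # Fine-tune on plain-text features by calling `fit()`.
+    features = ["The quick brown fox jumped.", "I forgot my homework."]
+    causal_lm.fit(x=features, batch_size=2)
+    ```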
+ """
+
+ backbone_cls = Phi4Backbone
+ preprocessor_cls = Phi4CausalLMPreprocessor
diff --git a/keras_hub/src/models/phi4/phi4_causal_lm_preprocessor.py b/keras_hub/src/models/phi4/phi4_causal_lm_preprocessor.py
new file mode 100644
index 0000000000..63d874daa3
--- /dev/null
+++ b/keras_hub/src/models/phi4/phi4_causal_lm_preprocessor.py
@@ -0,0 +1,83 @@
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.causal_lm_preprocessor import CausalLMPreprocessor
+from keras_hub.src.models.phi4.phi4_backbone import Phi4Backbone
+from keras_hub.src.models.phi4.phi4_tokenizer import Phi4Tokenizer
+
+
+@keras_hub_export("keras_hub.models.Phi4CausalLMPreprocessor")
+class Phi4CausalLMPreprocessor(CausalLMPreprocessor):
+ """Phi4 Causal LM preprocessor.
+
+ This preprocessing layer is meant for use with
+ `keras_hub.models.Phi4CausalLM`. By default, it will take in batches of
+ strings, and return outputs in a `(x, y, sample_weight)` format, where the
+ `y` label is the next token id in the `x` sequence.
+
+ For use with generation, the layer also exposes two methods
+ `generate_preprocess()` and `generate_postprocess()`. When this preprocessor
+ is attached to a `keras_hub.models.Phi4CausalLM` instance, these methods
+ will be called implicitly in `generate()`. They can also be called
+ standalone (e.g. to precompute preprocessing inputs for generation in a
+ separate process).
+
+ Args:
+ tokenizer: A `keras_hub.models.Phi4Tokenizer` instance.
+ sequence_length: The length of the packed inputs.
+        add_start_token: If `True`, the preprocessor will prepend the
+            tokenizer start token to each input sequence. Defaults to `True`.
+        add_end_token: If `True`, the preprocessor will append the tokenizer
+            end token to each input sequence. Defaults to `True`.
+
+ Call arguments:
+ x: A string, `tf.Tensor` or list of python strings.
+ y: Label data. Should always be `None` as the layer generates labels.
+ sample_weight: Label weights. Should always be `None` as the layer
+ generates label weights.
+ sequence_length: Pass to override the configured `sequence_length` of
+ the layer.
+
+ Examples:
+ ```python
+ # Load the preprocessor from a preset.
+ preprocessor = keras_hub.models.Phi4CausalLMPreprocessor.from_preset(
+ "phi4_mini_4k_instruct_en"
+ )
+
+ # Tokenize and pack a single sentence.
+ sentence = tf.constant("League of legends")
+ preprocessor(sentence)
+ # Same output.
+ preprocessor("League of legends")
+
+ # Tokenize a batch of sentences.
+ sentences = tf.constant(["Taco tuesday", "Fish taco please!"])
+ preprocessor(sentences)
+ # Same output.
+ preprocessor(["Taco tuesday", "Fish taco please!"])
+
+ # Map a dataset to preprocess a single sentence.
+ features = tf.constant(
+ [
+ "Avatar 2 is amazing!",
+ "Well, I am not sure.",
+ ]
+ )
+ labels = tf.constant([1, 0])
+ ds = tf.data.Dataset.from_tensor_slices((features, labels))
+ ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
+
+    # Map a dataset to preprocess unlabeled sentences.
+ ds = tf.data.Dataset.from_tensor_slices(features)
+ ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
+ ```
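+
+    The generation helpers can also be called standalone (a minimal sketch
+    of the round trip used inside `generate()`):
+    ```python
+    x = preprocessor.generate_preprocess("League of legends")
+    preprocessor.generate_postprocess(x)
+    ```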
+ """
+
+ backbone_cls = Phi4Backbone
+ tokenizer_cls = Phi4Tokenizer
diff --git a/keras_hub/src/models/phi4/phi4_causal_lm_preprocessor_test.py b/keras_hub/src/models/phi4/phi4_causal_lm_preprocessor_test.py
new file mode 100644
index 0000000000..96f207e817
--- /dev/null
+++ b/keras_hub/src/models/phi4/phi4_causal_lm_preprocessor_test.py
@@ -0,0 +1,92 @@
+import pytest
+
+from keras_hub.src.models.phi4.phi4_causal_lm_preprocessor import (
+ Phi4CausalLMPreprocessor,
+)
+from keras_hub.src.models.phi4.phi4_tokenizer import Phi4Tokenizer
+from keras_hub.src.tests.test_case import TestCase
+
+
+class Phi4CausalLMPreprocessorTest(TestCase):
+ def setUp(self):
+ self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"]
+ self.vocab += [
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ ]
+ self.vocab += ["", "", ""]
+ self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
+ self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
+ self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
+ self.merges += ["Ġai r", "Ġa i", "pla ne"]
+ self.tokenizer = Phi4Tokenizer(
+ vocabulary=self.vocab, merges=self.merges
+ )
+ self.init_kwargs = {
+ "tokenizer": self.tokenizer,
+ "sequence_length": 10,
+ }
+ # [1, 3, 4, 2, 5]
+ self.input_data = (["airplane at airport"],)
+
+ def test_preprocessor_basics(self):
+ self.run_preprocessor_test(
+ cls=Phi4CausalLMPreprocessor,
+ init_kwargs=self.init_kwargs,
+ input_data=self.input_data,
+ expected_output=(
+ {
+ "token_ids": [[6, 1, 3, 4, 2, 5, 0, 0, 0, 0]],
+ "padding_mask": [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],
+ },
+ [[1, 3, 4, 2, 5, 0, 0, 0, 0, 7]],
+ [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],
+ ),
+ )
+
+ def test_no_start_end_token(self):
+ input_data = ["airplane at airport"] * 4
+
+ preprocessor = Phi4CausalLMPreprocessor(
+ **self.init_kwargs,
+ add_start_token=False,
+ add_end_token=False,
+ )
+ x, y, sw = preprocessor(input_data)
+ self.assertAllEqual(
+ x["token_ids"], [[1, 3, 4, 2, 5, 0, 0, 0, 0, 0]] * 4
+ )
+ self.assertAllEqual(
+ x["padding_mask"], [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]] * 4
+ )
+ self.assertAllEqual(y, [[3, 4, 2, 5, 0, 0, 0, 0, 0, 0]] * 4)
+ self.assertAllEqual(sw, [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]] * 4)
+
+ def test_generate_preprocess(self):
+ input_data = "airplane at airport"
+ preprocessor = Phi4CausalLMPreprocessor(**self.init_kwargs)
+ x = preprocessor.generate_preprocess(input_data)
+ self.assertAllEqual(x["token_ids"], [6, 1, 3, 4, 2, 5, 0, 0, 0, 0])
+ self.assertAllEqual(x["padding_mask"], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
+
+ def test_generate_postprocess(self):
+ input_data = {
+ "token_ids": [1, 3, 4, 2, 5, 3, 9, 7, 11, 0],
+ "padding_mask": [1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
+ }
+ preprocessor = Phi4CausalLMPreprocessor(**self.init_kwargs)
+ x = preprocessor.generate_postprocess(input_data)
+ self.assertAllEqual(x, "airplane at airport")
+
+ @pytest.mark.extra_large
+ def test_all_presets(self):
+ for preset in Phi4CausalLMPreprocessor.presets:
+ self.run_preset_test(
+ cls=Phi4CausalLMPreprocessor,
+ preset=preset,
+ input_data=self.input_data,
+ )
diff --git a/keras_hub/src/models/phi4/phi4_causal_lm_test.py b/keras_hub/src/models/phi4/phi4_causal_lm_test.py
new file mode 100644
index 0000000000..28a7c2e797
--- /dev/null
+++ b/keras_hub/src/models/phi4/phi4_causal_lm_test.py
@@ -0,0 +1,128 @@
+from unittest.mock import patch
+
+import pytest
+from keras import ops
+
+from keras_hub.src.models.phi4.phi4_backbone import Phi4Backbone
+from keras_hub.src.models.phi4.phi4_causal_lm import Phi4CausalLM
+from keras_hub.src.models.phi4.phi4_causal_lm_preprocessor import (
+ Phi4CausalLMPreprocessor,
+)
+from keras_hub.src.models.phi4.phi4_tokenizer import Phi4Tokenizer
+from keras_hub.src.tests.test_case import TestCase
+
+
+class Phi4CausalLMTest(TestCase):
+ def setUp(self):
+        # Move "<|endoftext|>" to index 0 since the tokenizer pads with 0.
+        self.vocab = ["<|endoftext|>", "air", "Ġair", "plane", "Ġat", "port"]
+        self.vocab += [
+            "<|im_start|>",
+            "<|im_sep|>",
+            "!",
+            "<|im_end|>",
+            "<|fim_prefix|>",
+            "<|fim_middle|>",
+            # Necessary since `Phi3CausalLM` requires this in `generate()`
+            "<|end|>",
+        ]
+        self.vocab += ["<|fim_suffix|>", "<|dummy_0|>", "<|dummy_1|>"]
+ self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
+ self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
+ self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
+ self.merges += ["Ġai r", "Ġa i", "pla ne"]
+ self.preprocessor = Phi4CausalLMPreprocessor(
+ Phi4Tokenizer(vocabulary=self.vocab, merges=self.merges),
+ sequence_length=15,
+ )
+ self.vocab_size = self.preprocessor.tokenizer.vocabulary_size()
+ self.backbone = Phi4Backbone(
+ vocabulary_size=self.vocab_size,
+ num_layers=2,
+ num_query_heads=4,
+ num_key_value_heads=2,
+ hidden_dim=8,
+ intermediate_dim=16,
+ )
+ self.init_kwargs = {
+ "preprocessor": self.preprocessor,
+ "backbone": self.backbone,
+ }
+ self.train_data = ([" airplane at airport", " airplane at airport"],)
+ self.input_data = self.preprocessor(*self.train_data)[0]
+
+ def test_causal_lm_basics(self):
+ self.run_task_test(
+ cls=Phi4CausalLM,
+ init_kwargs=self.init_kwargs,
+ train_data=self.train_data,
+ expected_output_shape=(2, 15, self.vocab_size),
+ )
+
+ def test_generate(self):
+ causal_lm = Phi4CausalLM(**self.init_kwargs)
+ # String input.
+ prompt = " airplane at airport"
+        output = causal_lm.generate(prompt)
+ self.assertTrue(prompt in output)
+ # Int tensor input.
+ prompt_ids = self.preprocessor.generate_preprocess([prompt])
+ causal_lm.preprocessor = None
+ outputs = causal_lm.generate(prompt_ids, stop_token_ids=None)
+ # Assert prompt is in output in token id space.
+ self.assertAllEqual(
+ outputs["token_ids"][:, :5],
+ prompt_ids["token_ids"][:, :5],
+ )
+ self.assertAllEqual(
+ outputs["padding_mask"][:, :5],
+ prompt_ids["padding_mask"][:, :5],
+ )
+
+ def test_early_stopping(self):
+ causal_lm = Phi4CausalLM(**self.init_kwargs)
+ call_with_cache = causal_lm.call_with_cache
+
+ def wrapper(*args, **kwargs):
+ """Modify output logits to always favor end_token_id"""
+ logits, hidden_states, cache = call_with_cache(*args, **kwargs)
+ index = self.preprocessor.tokenizer.end_token_id
+ update = ops.ones_like(logits)[:, :, index] * 1.0e9
+ update = ops.expand_dims(update, axis=-1)
+ logits = ops.slice_update(logits, (0, 0, index), update)
+ return logits, hidden_states, cache
+
+ with patch.object(causal_lm, "call_with_cache", wraps=wrapper):
+ prompt = [" airplane at airport", " airplane"]
+ output = causal_lm.generate(prompt, max_length=7)
+ # We should immediately abort and output the prompt.
+ self.assertEqual(prompt, output)
+
+ def test_generate_compilation(self):
+ causal_lm = Phi4CausalLM(**self.init_kwargs)
+ # Assert we do not recompile with successive calls.
+ causal_lm.generate("the fox")
+ first_fn = causal_lm.generate_function
+ causal_lm.generate("the fox")
+ second_fn = causal_lm.generate_function
+ self.assertEqual(first_fn, second_fn)
+ # Assert we do recompile after compile is called.
+ causal_lm.compile(sampler="greedy")
+ self.assertIsNone(causal_lm.generate_function)
+
+ @pytest.mark.large
+ def test_saved_model(self):
+ self.run_model_saving_test(
+ cls=Phi4CausalLM,
+ init_kwargs=self.init_kwargs,
+ input_data=self.input_data,
+ )
+
+ @pytest.mark.extra_large
+ def test_all_presets(self):
+ for preset in Phi4CausalLM.presets:
+ self.run_preset_test(
+ cls=Phi4CausalLM,
+ preset=preset,
+ input_data=self.input_data,
+ )
diff --git a/keras_hub/src/models/phi4/phi4_tokenizer.py b/keras_hub/src/models/phi4/phi4_tokenizer.py
new file mode 100644
index 0000000000..de0ef3c3ef
--- /dev/null
+++ b/keras_hub/src/models/phi4/phi4_tokenizer.py
@@ -0,0 +1,104 @@
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.phi4.phi4_backbone import Phi4Backbone
+from keras_hub.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer
+
+
+@keras_hub_export(
+ [
+ "keras_hub.tokenizers.Phi4Tokenizer",
+ "keras_hub.models.Phi4Tokenizer",
+ ]
+)
+class Phi4Tokenizer(BytePairTokenizer):
+ """Phi4 tokenizer using Byte-Pair Encoding subword segmentation.
+
+ This tokenizer class will tokenize raw strings into integer sequences and
+ is based on `keras_hub.tokenizers.BytePairTokenizer`. Unlike the
+ underlying tokenizer, it will check for all special tokens needed by
+ Phi4 models and provides a `from_preset()` method to automatically
+ download a matching vocabulary for a Phi4 preset.
+
+ If input is a batch of strings (rank > 0), the layer will output a
+ `tf.RaggedTensor` where the last dimension of the output is ragged.
+
+ If input is a scalar string (rank == 0), the layer will output a dense
+ `tf.Tensor` with static shape `[None]`.
+
+ Args:
+ vocabulary: string or dict, maps token to integer ids. If it is a
+ string, it should be the file path to a json file.
+ merges: string or list, contains the merge rule. If it is a string,
+ it should be the file path to merge rules. The merge rule file
+ should have one merge rule per line. Every merge rule contains
+ merge entities separated by a space.
+        sequence_length: int. If set, the output will be padded or truncated
+            to the `sequence_length`. Defaults to `100_352`, following the
+            [Phi-4 Technical Report](https://arxiv.org/pdf/2412.08905).
+
+ Examples:
+ ```python
+ # Unbatched input.
+ tokenizer = keras_hub.models.Phi4Tokenizer.from_preset(
+ "phi4_mini_4k_instruct_en",
+ )
+ tokenizer("The quick brown fox jumped.")
+
+ # Batched input.
+ tokenizer(["The quick brown fox jumped.", "The fox slept."])
+
+ # Detokenization.
+ tokenizer.detokenize(tokenizer("The quick brown fox jumped."))
+ ```
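+
+    A toy in-memory vocabulary can also be passed directly (a sketch
+    mirroring this PR's unit tests; real presets ship complete vocabulary
+    and merges files):
+    ```python
+    vocab = {"!": 0, "air": 1, "Ġair": 2, "plane": 3, "Ġat": 4, "port": 5}
+    special_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_sep|>"]
+    special_tokens += ["<|im_end|>", "<|fim_prefix|>", "<|fim_middle|>"]
+    special_tokens += ["<|fim_suffix|>"]
+    vocab.update({token: 6 + i for i, token in enumerate(special_tokens)})
+    merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
+    merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
+    merges += ["Ġai r", "Ġa i", "pla ne"]
+    tokenizer = keras_hub.models.Phi4Tokenizer(
+        vocabulary=vocab, merges=merges, sequence_length=None
+    )
+    tokenizer(" airplane at airport")  # [2, 3, 4, 2, 5]
+    ```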
+
+ # References
+
+ - [Phi-4 tokenizer config](https://huggingface.co/microsoft/phi-4/raw/main/tokenizer.json)
+ """
+
+ backbone_cls = Phi4Backbone
+
+ def __init__(
+ self,
+ vocabulary=None,
+ merges=None,
+ sequence_length=100_352,
+ **kwargs,
+ ):
+ self._add_special_token("", "start_token")
+ self._add_special_token("", "end_token")
+ self._add_special_token("", "pad_token")
+
+ # FIM = Fill-in-the-middle, which uses special tokens to identify
+ # the prefix/middle/suffix part of the input/output for coding tasks.
+ self._add_special_token("", "fim_prefix")
+ self._add_special_token("", "fim_middle")
+ self._add_special_token("", "fix_suffix")
+
+ self._add_special_token("", "input_message_start")
+ self._add_special_token("", "input_message_separator")
+ self._add_special_token("", "input_message_end")
+
+ super().__init__(
+ vocabulary=vocabulary,
+ merges=merges,
+ sequence_length=sequence_length,
+ **kwargs,
+ )
diff --git a/keras_hub/src/models/phi4/phi4_tokenizer_test.py b/keras_hub/src/models/phi4/phi4_tokenizer_test.py
new file mode 100644
index 0000000000..e5bafa416c
--- /dev/null
+++ b/keras_hub/src/models/phi4/phi4_tokenizer_test.py
@@ -0,0 +1,61 @@
+import pytest
+
+from keras_hub.src.models.phi4.phi4_tokenizer import Phi4Tokenizer
+from keras_hub.src.tests.test_case import TestCase
+
+
+class Phi4TokenizerTest(TestCase):
+ def setUp(self):
+ self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"]
+ self.vocab += [
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ ]
+ self.vocab += ["", "", ""]
+ self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
+ self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
+ self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
+ self.merges += ["Ġai r", "Ġa i", "pla ne"]
+ self.init_kwargs = {
+ "vocabulary": self.vocab,
+ "merges": self.merges,
+ "sequence_length": None,
+ }
+        self.input_data = [
+            "<|endoftext|> airplane at airport<|im_start|><|im_sep|>",
+            " airplane airport",
+        ]
+
+ def test_tokenizer_basics(self):
+ self.run_preprocessing_layer_test(
+ cls=Phi4Tokenizer,
+ init_kwargs=self.init_kwargs,
+ input_data=self.input_data,
+ expected_output=[[6, 2, 3, 4, 2, 5, 7, 8], [2, 3, 2, 5]],
+ )
+
+ def test_errors_missing_special_tokens(self):
+ with self.assertRaises(ValueError):
+ Phi4Tokenizer(vocabulary={"foo": 0, "bar": 1}, merges=["fo o"])
+
+ @pytest.mark.large
+ def test_smallest_preset(self):
+ self.run_preset_test(
+ cls=Phi4Tokenizer,
+ preset="phi4_mini_4k_instruct_en",
+ input_data=["The quick brown fox."],
+ expected_output=[[791, 4062, 14198, 39935, 13]],
+ )
+
+ @pytest.mark.extra_large
+ def test_all_presets(self):
+ for preset in Phi4Tokenizer.presets:
+ self.run_preset_test(
+ cls=Phi4Tokenizer,
+ preset=preset,
+ input_data=self.input_data,
+ )