Skip to content

Commit 684a2eb

Browse files
committed
Port models to core (#1119)
* Port models to core * Proper seed generation for jax * Don't test metrics yet (for a separate PR) * Add all model variables to the jax state mapping. We want to avoid a bug in the case that looks like: model.generate(); model.fit(); model.generate(). In this case we need to be careful not to pull in the cached variable state at generation compile time. * Address Ian's comments * Add TODOs for reverse embedding * Run pytest on the entirety of keras-nlp * Misc cleanups * Mark docstring tests tf only * Last failing doctest
1 parent 6dfb247 commit 684a2eb

File tree

130 files changed

+1072
-841
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below for content that may be hidden.

130 files changed

+1072
-841
lines changed

.github/workflows/actions.yml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -102,7 +102,7 @@ jobs:
102102
env:
103103
KERAS_BACKEND: ${{ matrix.backend }}
104104
run: |
105-
pytest --run_large keras_nlp/layers/modeling keras_nlp/samplers keras_nlp/tokenizers keras_nlp/metrics
105+
pytest keras_nlp/
106106
format:
107107
name: Check the code format
108108
runs-on: ubuntu-latest

keras_nlp/conftest.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -122,7 +122,12 @@ def pytest_collection_modifyitems(config, items):
122122
item.add_marker(skip_tf_only)
123123

124124

125+
# Disable traceback filtering for quicker debugging of test failures.
126+
tf.debugging.disable_traceback_filtering()
125127
if backend_config.multi_backend():
126128
keras.config.disable_traceback_filtering()
127129

128-
tf.debugging.disable_traceback_filtering()
130+
# One-off setup for dtensor tests.
131+
if not backend_config.multi_backend():
132+
keras.backend.experimental.enable_tf_random_generator()
133+
keras.utils.set_random_seed(1337)

keras_nlp/metrics/rouge_l.py

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -102,14 +102,13 @@ class RougeL(RougeBase):
102102
103103
3. Pass the metric to `model.compile()`.
104104
>>> inputs = keras.Input(shape=(), dtype='string')
105-
>>> outputs = tf.strings.lower(inputs)
105+
>>> outputs = keras.layers.Identity()(inputs)
106106
>>> model = keras.Model(inputs, outputs)
107107
>>> model.compile(metrics=[keras_nlp.metrics.RougeL()])
108-
>>> x = tf.constant(["HELLO THIS IS FUN"])
108+
>>> y_pred = x = tf.constant(["hello this is fun"])
109109
>>> y = tf.constant(["hello this is awesome"])
110-
>>> metric_dict = model.evaluate(x, y, return_dict=True)
111-
>>> metric_dict["f1_score"]
112-
0.75
110+
>>> model.compute_metrics(x, y, y_pred, sample_weight=None)["f1_score"]
111+
0.75
113112
"""
114113

115114
def __init__(

keras_nlp/metrics/rouge_n.py

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -121,13 +121,12 @@ class RougeN(RougeBase):
121121
122122
3. Pass the metric to `model.compile()`.
123123
>>> inputs = keras.Input(shape=(), dtype='string')
124-
>>> outputs = tf.strings.lower(inputs)
124+
>>> outputs = keras.layers.Identity()(inputs)
125125
>>> model = keras.Model(inputs, outputs)
126126
>>> model.compile(metrics=[keras_nlp.metrics.RougeN()])
127-
>>> x = tf.constant(["HELLO THIS IS FUN"])
127+
>>> y_pred = x = tf.constant(["hello this is fun"])
128128
>>> y = tf.constant(["hello this is awesome"])
129-
>>> metric_dict = model.evaluate(x, y, return_dict=True)
130-
>>> metric_dict["f1_score"]
129+
>>> model.compute_metrics(x, y, y_pred, sample_weight=None)["f1_score"]
131130
0.6666666865348816
132131
"""
133132

keras_nlp/models/albert/albert_backbone_test.py

Lines changed: 10 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,7 @@
1919
import tensorflow as tf
2020

2121
from keras_nlp.backend import keras
22+
from keras_nlp.backend import ops
2223
from keras_nlp.models.albert.albert_backbone import AlbertBackbone
2324
from keras_nlp.tests.test_case import TestCase
2425

@@ -38,9 +39,9 @@ def setUp(self):
3839
)
3940
self.batch_size = 8
4041
self.input_batch = {
41-
"token_ids": tf.ones((2, 5), dtype="int32"),
42-
"segment_ids": tf.ones((2, 5), dtype="int32"),
43-
"padding_mask": tf.ones((2, 5), dtype="int32"),
42+
"token_ids": ops.ones((2, 5), dtype="int32"),
43+
"segment_ids": ops.ones((2, 5), dtype="int32"),
44+
"padding_mask": ops.ones((2, 5), dtype="int32"),
4445
}
4546

4647
self.input_dataset = tf.data.Dataset.from_tensor_slices(
@@ -57,9 +58,9 @@ def test_name(self):
5758
def test_variable_sequence_length_call_albert(self):
5859
for seq_length in (2, 3, 4):
5960
input_data = {
60-
"token_ids": tf.ones((2, seq_length), dtype="int32"),
61-
"segment_ids": tf.ones((2, seq_length), dtype="int32"),
62-
"padding_mask": tf.ones((2, seq_length), dtype="int32"),
61+
"token_ids": ops.ones((2, seq_length), dtype="int32"),
62+
"segment_ids": ops.ones((2, seq_length), dtype="int32"),
63+
"padding_mask": ops.ones((2, seq_length), dtype="int32"),
6364
}
6465
self.backbone(input_data)
6566

@@ -121,9 +122,9 @@ def setUp(self):
121122
)
122123

123124
self.input_batch = {
124-
"token_ids": tf.ones((8, 128), dtype="int32"),
125-
"segment_ids": tf.ones((8, 128), dtype="int32"),
126-
"padding_mask": tf.ones((8, 128), dtype="int32"),
125+
"token_ids": ops.ones((8, 128), dtype="int32"),
126+
"segment_ids": ops.ones((8, 128), dtype="int32"),
127+
"padding_mask": ops.ones((8, 128), dtype="int32"),
127128
}
128129
self.input_dataset = tf.data.Dataset.from_tensor_slices(
129130
self.input_batch

keras_nlp/models/albert/albert_classifier.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,6 @@
2222
from keras_nlp.models.albert.albert_preprocessor import AlbertPreprocessor
2323
from keras_nlp.models.albert.albert_presets import backbone_presets
2424
from keras_nlp.models.task import Task
25-
from keras_nlp.utils.keras_utils import is_xla_compatible
2625
from keras_nlp.utils.python_utils import classproperty
2726

2827

@@ -192,7 +191,7 @@ def __init__(
192191
),
193192
optimizer=keras.optimizers.Adam(5e-5),
194193
metrics=[keras.metrics.SparseCategoricalAccuracy()],
195-
jit_compile=is_xla_compatible(self),
194+
jit_compile=True,
196195
)
197196

198197
def get_config(self):

keras_nlp/models/albert/albert_classifier_test.py

Lines changed: 7 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,7 @@
2121
import tensorflow as tf
2222

2323
from keras_nlp.backend import keras
24+
from keras_nlp.backend import ops
2425
from keras_nlp.models.albert.albert_backbone import AlbertBackbone
2526
from keras_nlp.models.albert.albert_classifier import AlbertClassifier
2627
from keras_nlp.models.albert.albert_preprocessor import AlbertPreprocessor
@@ -77,15 +78,13 @@ def setUp(self):
7778
activation=keras.activations.softmax,
7879
)
7980

80-
self.raw_batch = tf.constant(
81-
[
82-
"the quick brown fox.",
83-
"the slow brown fox.",
84-
]
85-
)
81+
self.raw_batch = [
82+
"the quick brown fox.",
83+
"the slow brown fox.",
84+
]
8685
self.preprocessed_batch = self.preprocessor(self.raw_batch)
8786
self.raw_dataset = tf.data.Dataset.from_tensor_slices(
88-
(self.raw_batch, tf.ones((2,)))
87+
(self.raw_batch, ops.ones((2,)))
8988
).batch(2)
9089
self.preprocessed_dataset = self.raw_dataset.map(self.preprocessor)
9190

@@ -99,7 +98,7 @@ def test_classifier_predict(self):
9998
# Assert predictions match.
10099
self.assertAllClose(preds1, preds2)
101100
# Assert valid softmax output.
102-
self.assertAllClose(tf.reduce_sum(preds2, axis=-1), [1.0, 1.0])
101+
self.assertAllClose(ops.sum(preds2, axis=-1), [1.0, 1.0])
103102

104103
def test_classifier_fit(self):
105104
self.classifier.fit(self.raw_dataset)

keras_nlp/models/albert/albert_masked_lm.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -26,7 +26,6 @@
2626
)
2727
from keras_nlp.models.albert.albert_presets import backbone_presets
2828
from keras_nlp.models.task import Task
29-
from keras_nlp.utils.keras_utils import is_xla_compatible
3029
from keras_nlp.utils.python_utils import classproperty
3130

3231

@@ -135,7 +134,7 @@ def __init__(self, backbone, preprocessor=None, **kwargs):
135134
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
136135
optimizer=keras.optimizers.Adam(5e-5),
137136
weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
138-
jit_compile=is_xla_compatible(self),
137+
jit_compile=True,
139138
)
140139

141140
@classproperty

keras_nlp/models/albert/albert_masked_lm_preprocessor_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ def test_serialization(self):
152152
)
153153

154154
@pytest.mark.large
155+
@pytest.mark.tf_only
155156
def test_saved_model(self):
156157
input_data = tf.constant(["the quick brown fox"])
157158

keras_nlp/models/albert/albert_masked_lm_test.py

Lines changed: 6 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -85,14 +85,12 @@ def setUp(self):
8585
preprocessor=None,
8686
)
8787

88-
self.raw_batch = tf.constant(
89-
[
90-
"quick brown fox",
91-
"eagle flew over fox",
92-
"the eagle flew quick",
93-
"a brown eagle",
94-
]
95-
)
88+
self.raw_batch = [
89+
"quick brown fox",
90+
"eagle flew over fox",
91+
"the eagle flew quick",
92+
"a brown eagle",
93+
]
9694
self.preprocessed_batch = self.preprocessor(self.raw_batch)[0]
9795
self.raw_dataset = tf.data.Dataset.from_tensor_slices(
9896
self.raw_batch

0 commit comments

Comments
 (0)