Skip to content

Commit 0f35d5e

Browse files
committed
Fix saved classifier models from before 0.14 (#1839)
We switched the class name for `XXClassifier` models to `XXTextClasssifier`. However, a saved classifier before 0.16 will still be looking for the class under the old name. This updates our export helper to also registered the old name, so we can restore to the new class when loading the model. I also try to improve our error messages when we do encounter an unrecognized class.
1 parent 99df05b commit 0f35d5e

File tree

4 files changed

+41
-3
lines changed

4 files changed

+41
-3
lines changed

keras_nlp/src/api_export.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,16 @@
2222
namex = None
2323

2424

25-
def maybe_register_serializable(symbol):
25+
def maybe_register_serializable(path, symbol):
26+
# If we have multiple export names, actually make sure to register these
27+
# first. This makes sure we have a backward compat mapping of old serialized
28+
# name to new class.
29+
if isinstance(path, (list, tuple)):
30+
for name in path:
31+
name = name.split(".")[-1]
32+
keras.saving.register_keras_serializable(
33+
package="keras_nlp", name=name
34+
)(symbol)
2635
if isinstance(symbol, types.FunctionType) or hasattr(symbol, "get_config"):
2736
keras.saving.register_keras_serializable(package="keras_nlp")(symbol)
2837

@@ -34,7 +43,7 @@ def __init__(self, path):
3443
super().__init__(package="keras_nlp", path=path)
3544

3645
def __call__(self, symbol):
37-
maybe_register_serializable(symbol)
46+
maybe_register_serializable(self.path, symbol)
3847
return super().__call__(symbol)
3948

4049
else:

keras_nlp/src/models/bert/bert_text_classifier_test.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,15 @@ def test_saved_model(self):
6767
input_data=self.input_data,
6868
)
6969

70+
@pytest.mark.large
71+
def test_smallest_preset(self):
72+
self.run_preset_test(
73+
cls=BertTextClassifier,
74+
preset="bert_tiny_en_uncased_sst2",
75+
input_data=self.input_data,
76+
expected_output_shape=(2, 2),
77+
)
78+
7079
@pytest.mark.extra_large
7180
def test_all_presets(self):
7281
for preset in BertTextClassifier.presets:

keras_nlp/src/utils/preset_utils.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -582,7 +582,16 @@ def load_serialized_object(config, **kwargs):
582582

583583
def check_config_class(config):
584584
"""Validate a preset is being loaded on the correct class."""
585-
return keras.saving.get_registered_object(config["registered_name"])
585+
registered_name = config["registered_name"]
586+
cls = keras.saving.get_registered_object(registered_name)
587+
if cls is None:
588+
raise ValueError(
589+
f"Attempting to load class {registered_name} with "
590+
"`from_preset()`, but there is no class registered with Keras "
591+
f"for {registered_name}. Make sure to register any custom "
592+
"classes with `register_keras_serializable()`."
593+
)
594+
return cls
586595

587596

588597
def jax_memory_cleanup(layer):

keras_nlp/src/utils/preset_utils_test.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,10 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import json
1516
import os
1617

18+
import keras
1719
import pytest
1820
from absl.testing import parameterized
1921

@@ -38,6 +40,15 @@ def test_preset_errors(self):
3840
with self.assertRaisesRegex(ValueError, "Unknown preset identifier"):
3941
AlbertTextClassifier.from_preset("snaggle://bort/bort/bort")
4042

43+
backbone = BertBackbone.from_preset("bert_tiny_en_uncased")
44+
preset_dir = self.get_temp_dir()
45+
config = keras.utils.serialize_keras_object(backbone)
46+
config["registered_name"] = "keras_nlp>BortBackbone"
47+
with open(os.path.join(preset_dir, CONFIG_FILE), "w") as config_file:
48+
config_file.write(json.dumps(config, indent=4))
49+
with self.assertRaisesRegex(ValueError, "class keras_nlp>BortBackbone"):
50+
BertBackbone.from_preset(preset_dir)
51+
4152
def test_upload_empty_preset(self):
4253
temp_dir = self.get_temp_dir()
4354
empty_preset = os.path.join(temp_dir, "empty")

0 commit comments

Comments
 (0)