Skip to content

Commit a5da750

Browse files
Allow a task preprocessor to be an argument in from_preset. (#1603)
1 parent 4c8d0bc commit a5da750

File tree

2 files changed

+14
-1
lines changed

2 files changed

+14
-1
lines changed

keras_nlp/models/task.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,10 @@ def from_preset(
282282
load_weights=load_weights,
283283
config_overrides=config_overrides,
284284
)
285-
preprocessor = cls.preprocessor_cls.from_preset(preset)
285+
if "preprocessor" in kwargs:
286+
preprocessor = kwargs.pop("preprocessor")
287+
else:
288+
preprocessor = cls.preprocessor_cls.from_preset(preset)
286289
return cls(backbone=backbone, preprocessor=preprocessor, **kwargs)
287290

288291
def load_task_weights(self, filepath):

keras_nlp/models/task_test.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,3 +138,13 @@ def test_save_to_preset(self):
138138
ref_out = model.predict(data)
139139
new_out = restored_model.predict(data)
140140
self.assertAllEqual(ref_out, new_out)
141+
142+
@pytest.mark.keras_3_only
143+
@pytest.mark.large
144+
def test_none_preprocessor(self):
145+
model = Classifier.from_preset(
146+
"bert_tiny_en_uncased",
147+
preprocessor=None,
148+
num_classes=2,
149+
)
150+
self.assertEqual(model.preprocessor, None)

0 commit comments

Comments
 (0)