Skip to content

Commit 3f80160

Browse files
authored
Fix dtype when creating a task with a task.json (#2007)
1 parent cefc7b6 commit 3f80160

File tree

2 files changed

+10
-0
lines changed

2 files changed

+10
-0
lines changed

keras_hub/src/models/task_test.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,10 @@ def test_save_to_preset(self):
147147
new_out = restored_task.backbone.predict(data)
148148
self.assertAllClose(ref_out, new_out)
149149

150+
# Check setting dtype.
151+
restored_task = TextClassifier.from_preset(save_dir, dtype="float16")
152+
self.assertEqual("float16", restored_task.backbone.dtype_policy.name)
153+
150154
@pytest.mark.large
151155
def test_save_to_preset_custom_backbone_and_preprocessor(self):
152156
preprocessor = keras.layers.Rescaling(1 / 255.0)

keras_hub/src/utils/preset_utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -658,6 +658,12 @@ def load_task(self, cls, load_weights, load_task_weights, **kwargs):
658658
cls, load_weights, load_task_weights, **kwargs
659659
)
660660
# We found a `task.json` with a complete config for our class.
661+
# Forward backbone args.
662+
backbone_kwargs, kwargs = self.get_backbone_kwargs(**kwargs)
663+
if "backbone" in task_config["config"]:
664+
backbone_config = task_config["config"]["backbone"]["config"]
665+
backbone_config = {**backbone_config, **backbone_kwargs}
666+
task_config["config"]["backbone"]["config"] = backbone_config
661667
task = load_serialized_object(task_config, **kwargs)
662668
if task.preprocessor and hasattr(
663669
task.preprocessor, "load_preset_assets"

0 commit comments

Comments
 (0)