Skip to content

Commit 9fcafa7

Browse files
authored
Add imagenet prediction decoder (#1848)
1 parent adf6dc7 commit 9fcafa7

File tree

6 files changed

+1167
-3
lines changed

6 files changed

+1167
-3
lines changed

keras_hub/api/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from keras_hub.api import models
2424
from keras_hub.api import samplers
2525
from keras_hub.api import tokenizers
26+
from keras_hub.api import utils
2627
from keras_hub.src.utils.preset_utils import upload_preset
2728
from keras_hub.src.version_utils import __version__
2829
from keras_hub.src.version_utils import version

keras_hub/api/utils/__init__.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# Copyright 2024 The KerasHub Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""DO NOT EDIT.
15+
16+
This file was autogenerated. Do not edit it by hand,
17+
since your modifications would be overwritten.
18+
"""
19+
20+
from keras_hub.src.utils.imagenet.imagenet_utils import (
21+
decode_imagenet_predictions,
22+
)

keras_hub/src/models/resnet/resnet_image_classifier.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,9 @@ class ResNetImageClassifier(ImageClassifier):
4545
```python
4646
# Load preset and train
4747
images = np.ones((2, 224, 224, 3), dtype="float32")
48-
classifier = keras_hub.models.ResNetImageClassifier.from_preset("resnet50")
48+
classifier = keras_hub.models.ResNetImageClassifier.from_preset(
49+
"resnet_50_imagenet"
50+
)
4951
classifier.predict(images)
5052
```
5153
@@ -54,13 +56,17 @@ class ResNetImageClassifier(ImageClassifier):
5456
# Load preset and train
5557
images = np.ones((2, 224, 224, 3), dtype="float32")
5658
labels = [0, 3]
57-
classifier = keras_hub.models.ResNetImageClassifier.from_preset("resnet50")
59+
classifier = keras_hub.models.ResNetImageClassifier.from_preset(
60+
"resnet_50_imagenet"
61+
)
5862
classifier.fit(x=images, y=labels, batch_size=2)
5963
```
6064
6165
Call `fit()` with custom loss, optimizer and backbone.
6266
```python
63-
classifier = keras_hub.models.ResNetImageClassifier.from_preset("resnet50")
67+
classifier = keras_hub.models.ResNetImageClassifier.from_preset(
68+
"resnet_50_imagenet"
69+
)
6470
classifier.compile(
6571
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
6672
optimizer=keras.optimizers.Adam(5e-5),
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright 2024 The KerasNLP Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.

0 commit comments

Comments
 (0)