Skip to content

Commit cc450d7

Browse files
authored
Migrate VGG16 from legacy to backbone (#2341)
* Add VGG16 to backbone from legacy
* Add backbone tests
* Add model to __init__.py
* Fix code format for vgg16 backbone
1 parent 15db57c commit cc450d7

File tree

4 files changed

+308
-0
lines changed

4 files changed

+308
-0
lines changed

keras_cv/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,7 @@
178178
from keras_cv.models.backbones.resnet_v2.resnet_v2_backbone import (
179179
ResNetV2Backbone,
180180
)
181+
from keras_cv.models.backbones.vgg16.vgg16_backbone import VGG16Backbone
181182
from keras_cv.models.backbones.vit_det.vit_det_aliases import ViTDetBBackbone
182183
from keras_cv.models.backbones.vit_det.vit_det_aliases import ViTDetHBackbone
183184
from keras_cv.models.backbones.vit_det.vit_det_aliases import ViTDetLBackbone
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright 2023 The KerasCV Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
1+
# Copyright 2023 The KerasCV Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from keras import layers
16+
17+
from keras_cv.models import utils
18+
from keras_cv.models.backbones.backbone import Backbone
19+
20+
21+
class VGG16Backbone(Backbone):
    """Keras backbone implementation of the VGG16 architecture.

    Reference:
    - [Very Deep Convolutional Networks for Large-Scale Image Recognition](https://arxiv.org/abs/1409.1556)
    (ICLR 2015)

    The network is five convolutional blocks (2, 2, 3, 3 and 3 `Conv2D`
    layers with 64, 128, 256, 512 and 512 filters respectively), each
    followed by a 2x2 max-pooling layer. Depending on `include_top`, it ends
    either in the three fully-connected classifier layers or in an optional
    global pooling layer.

    Args:
        include_rescaling: bool, whether to rescale the inputs. If set to
            True, inputs will be passed through a `Rescaling(1/255.0)`
            layer.
        include_top: bool, whether to include the 3 fully-connected
            layers at the top of the network. If True, `num_classes` must
            be provided and `pooling` must be left as `None`.
        input_tensor: Tensor, optional Keras tensor (i.e. output of
            `layers.Input()`) to use as image input for the model.
        num_classes: int, optional number of classes to classify images
            into, only to be specified if `include_top` is True.
        input_shape: tuple, optional shape tuple, defaults to
            (224, 224, 3).
        pooling: str, optional pooling mode for feature extraction
            when `include_top` is `False`.
            - `None` means that the output of the model will be
                the 4D tensor output of the last convolutional block.
            - `"avg"` means that global average pooling will be applied to
                the output of the last convolutional block, and thus the
                output of the model will be a 2D tensor.
            - `"max"` means that global max pooling will be applied.
            Any other value leaves the output of the last convolutional
            block unpooled.
        classifier_activation: `str` or callable. The activation function to
            use on the "top" layer. Ignored unless `include_top=True`. Set
            `classifier_activation=None` to return the logits of the "top"
            layer. When loading pretrained weights, `classifier_activation`
            can only be `None` or `"softmax"`.
        name: (Optional) name to pass to the model, defaults to "VGG16".
        **kwargs: forwarded unchanged to the `Backbone` constructor.

    Raises:
        ValueError: if `include_top` is True and `num_classes` is `None`, or
            if `include_top` is True and `pooling` is set.
    """  # noqa: E501

    def __init__(
        self,
        include_rescaling,
        include_top,
        input_tensor=None,
        num_classes=None,
        input_shape=(224, 224, 3),
        pooling=None,
        classifier_activation="softmax",
        name="VGG16",
        **kwargs,
    ):
        if include_top and num_classes is None:
            raise ValueError(
                "If `include_top` is True, you should specify `num_classes`. "
                f"Received: num_classes={num_classes}"
            )

        if include_top and pooling:
            # The two sentences used to be concatenated without a separating
            # space, and the message ended in a stray trailing space.
            raise ValueError(
                "`pooling` must be `None` when `include_top=True`. "
                f"Received pooling={pooling} and include_top={include_top}."
            )

        # NOTE(review): `parse_model_inputs` is defined outside this file;
        # presumably it returns `input_tensor` when given, otherwise a new
        # input of shape `input_shape` -- confirm against keras_cv utils.
        img_input = utils.parse_model_inputs(input_shape, input_tensor)
        x = img_input

        if include_rescaling:
            x = layers.Rescaling(scale=1 / 255.0)(x)

        # (number of conv layers, filters) for block1 .. block5.
        block_specs = ((2, 64), (2, 128), (3, 256), (3, 512), (3, 512))
        for index, (num_layers, filters) in enumerate(block_specs, start=1):
            x = apply_vgg_block(
                x=x,
                num_layers=num_layers,
                filters=filters,
                kernel_size=(3, 3),
                activation="relu",
                padding="same",
                max_pool=True,
                name=f"block{index}",
            )

        if include_top:
            x = layers.Flatten(name="flatten")(x)
            x = layers.Dense(4096, activation="relu", name="fc1")(x)
            x = layers.Dense(4096, activation="relu", name="fc2")(x)
            x = layers.Dense(
                num_classes,
                activation=classifier_activation,
                name="predictions",
            )(x)
        else:
            if pooling == "avg":
                x = layers.GlobalAveragePooling2D()(x)
            elif pooling == "max":
                x = layers.GlobalMaxPooling2D()(x)

        super().__init__(inputs=img_input, outputs=x, name=name, **kwargs)

        # Keep the constructor arguments so `get_config` can report them.
        self.include_rescaling = include_rescaling
        self.include_top = include_top
        self.num_classes = num_classes
        self.input_tensor = input_tensor
        self.pooling = pooling
        self.classifier_activation = classifier_activation

    def get_config(self):
        """Return the constructor arguments of this model as a dict.

        `input_shape` is reported as `self.input_shape[1:]`, i.e. with the
        leading (presumably batch) dimension dropped.
        """
        return {
            "include_rescaling": self.include_rescaling,
            "include_top": self.include_top,
            "name": self.name,
            "input_shape": self.input_shape[1:],
            "input_tensor": self.input_tensor,
            "pooling": self.pooling,
            "num_classes": self.num_classes,
            "classifier_activation": self.classifier_activation,
            "trainable": self.trainable,
        }
181+
182+
183+
def apply_vgg_block(
    x,
    num_layers,
    filters,
    kernel_size,
    activation,
    padding,
    max_pool,
    name,
):
    """
    Stacks the convolution layers of one VGG block on top of `x`.

    Args:
        x: Tensor, input tensor the block is applied to
        num_layers: int, how many `Conv2D` layers the block contains
        filters: int, number of filters of every `Conv2D` layer in the block
        kernel_size: int (or) tuple, kernel size of every `Conv2D` layer
        activation: str (or) callable, activation of every `Conv2D` layer
        padding: str, value forwarded as the `padding` argument of every
            `Conv2D` layer
        max_pool: bool, whether the block ends with a 2x2, stride-2
            `MaxPooling2D` layer
        name: str, prefix for layer names; convolutions are named
            `{name}_conv1`, `{name}_conv2`, ... and the pooling layer
            `{name}_pool`

    Returns:
        The output tensor of the last layer in the block.
    """
    for index in range(num_layers):
        conv = layers.Conv2D(
            filters=filters,
            kernel_size=kernel_size,
            activation=activation,
            padding=padding,
            # Layer names are 1-based.
            name=f"{name}_conv{index + 1}",
        )
        x = conv(x)

    if not max_pool:
        return x

    pool = layers.MaxPooling2D(
        pool_size=(2, 2), strides=(2, 2), name=f"{name}_pool"
    )
    return pool(x)
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# Copyright 2023 The KerasCV Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
17+
import numpy as np
18+
import pytest
19+
20+
from keras_cv.backend import keras
21+
from keras_cv.models import VGG16Backbone
22+
from keras_cv.tests.test_case import TestCase
23+
24+
25+
class VGG16BackboneTest(TestCase):
    """Smoke and save/load round-trip tests for `VGG16Backbone`."""

    def setUp(self):
        # A batch of two all-ones 224x224 RGB images used by every test.
        self.img_input = np.ones(shape=(2, 224, 224, 3), dtype="float32")

    def test_valid_call(self):
        """A headless, average-pooled model can be called on the batch."""
        backbone = VGG16Backbone(
            include_rescaling=False,
            include_top=False,
            input_shape=(224, 224, 3),
            pooling="avg",
        )
        backbone(self.img_input)

    def test_valid_call_with_rescaling(self):
        """Same as `test_valid_call`, with the rescaling layer enabled."""
        backbone = VGG16Backbone(
            include_rescaling=True,
            include_top=False,
            input_shape=(224, 224, 3),
            pooling="avg",
        )
        backbone(self.img_input)

    def test_valid_call_with_top(self):
        """A model with the two-class classifier head can be called."""
        backbone = VGG16Backbone(
            include_rescaling=False,
            include_top=True,
            input_shape=(224, 224, 3),
            num_classes=2,
        )
        backbone(self.img_input)

    @pytest.mark.large
    def test_saved_model(self):
        """Saving to `.keras` and reloading keeps the type and the outputs."""
        backbone = VGG16Backbone(
            include_rescaling=False,
            include_top=False,
            input_shape=(224, 224, 3),
            num_classes=2,
            pooling="avg",
        )
        expected_output = backbone(self.img_input)

        archive_path = os.path.join(self.get_temp_dir(), "vgg16.keras")
        backbone.save(archive_path)
        reloaded = keras.models.load_model(archive_path)

        # The reloaded object must be rebuilt as the same backbone class.
        self.assertIsInstance(reloaded, VGG16Backbone)

        # The reloaded model must give the same output for the same input.
        reloaded_output = reloaded(self.img_input)
        self.assertAllClose(expected_output, reloaded_output)

0 commit comments

Comments
 (0)