Commit 191a804

make vit compatible with non square images (#2255)
* make vit compatible with non square images
* fix converter issue
* update presets
* use std for scale looping not mean
* patch size can also be int dtype
1 parent c314f88 commit 191a804
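
Sketch of the usage this commit enables: a ViT backbone built on a rectangular input with a rectangular patch grid. The import path matches the files touched in this diff; the hyperparameter values (including `mlp_dim`) are illustrative, not taken from the commit.

```python
# Minimal sketch of the new capability: a non-square input and a patch
# size given as a pair. Hyperparameter values are illustrative.
import numpy as np

from keras_hub.src.models.vit.vit_backbone import ViTBackbone

backbone = ViTBackbone(
    image_shape=(224, 448, 3),  # height and width no longer need to match
    patch_size=(16, 16),        # int or (int, int)
    num_layers=3,
    num_heads=6,
    hidden_dim=48,
    mlp_dim=96,                 # assumed constructor argument, illustrative value
)
images = np.random.rand(2, 224, 448, 3).astype("float32")
features = backbone(images)
# (224 // 16) * (448 // 16) = 14 * 28 = 392 patches, +1 class token
print(features.shape)  # expected: (2, 393, 48)
```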

7 files changed, +89 −117 lines changed

keras_hub/src/models/vit/vit_backbone.py

Lines changed: 31 additions & 11 deletions
```diff
@@ -18,10 +18,10 @@ class ViTBackbone(Backbone):
 
     Args:
         image_shape: A tuple or list of 3 integers representing the shape of the
-            input image `(height, width, channels)`, `height` and `width` must
-            be equal.
-        patch_size: int. The size of each image patch, the input image will be
-            divided into patches of shape `(patch_size, patch_size)`.
+            input image `(height, width, channels)`.
+        patch_size: int or (int, int). The size of each image patch, the input
+            image will be divided into patches of shape
+            `(patch_size_h, patch_size_w)`.
         num_layers: int. The number of transformer encoder layers.
         num_heads: int. specifying the number of attention heads in each
             Transformer encoder layer.
@@ -37,6 +37,10 @@ class ViTBackbone(Backbone):
         use_mha_bias: bool. Whether to use bias in the multi-head
             attention layers.
         use_mlp_bias: bool. Whether to use bias in the MLP layers.
+        use_class_token: bool. Whether to use class token to be part of
+            patch embedding. Defaults to `True`.
+        use_patch_bias: bool. Whether to use bias in Conv2d of patch embedding
+            layer. Defaults to `True`.
         data_format: str. `"channels_last"` or `"channels_first"`, specifying
             the data format for the input image. If `None`, defaults to
             `"channels_last"`.
@@ -58,6 +62,8 @@ def __init__(
         layer_norm_epsilon=1e-6,
         use_mha_bias=True,
         use_mlp_bias=True,
+        use_class_token=True,
+        use_patch_bias=True,
         data_format=None,
         dtype=None,
         **kwargs,
@@ -74,24 +80,34 @@ def __init__(
                 f"at index {h_axis} (height) or {w_axis} (width). "
                 f"Image shape: {image_shape}"
             )
-        if image_shape[h_axis] != image_shape[w_axis]:
+
+        if isinstance(patch_size, int):
+            patch_size = (patch_size, patch_size)
+
+        if image_shape[h_axis] % patch_size[0] != 0:
+            raise ValueError(
+                f"Input height {image_shape[h_axis]} should be divisible by "
+                f"patch size {patch_size[0]}."
+            )
+
+        if image_shape[w_axis] % patch_size[1] != 0:
             raise ValueError(
-                f"Image height and width must be equal. Found height: "
-                f"{image_shape[h_axis]}, width: {image_shape[w_axis]} at "
-                f"indices {h_axis} and {w_axis} respectively. Image shape: "
-                f"{image_shape}"
+                f"Input width {image_shape[h_axis]} should be divisible by "
+                f"patch size {patch_size[1]}."
             )
 
         num_channels = image_shape[channels_axis]
 
         # === Functional Model ===
-        inputs = keras.layers.Input(shape=image_shape)
+        inputs = keras.layers.Input(shape=image_shape, name="images")
 
         x = ViTPatchingAndEmbedding(
-            image_size=image_shape[h_axis],
+            image_size=(image_shape[h_axis], image_shape[w_axis]),
             patch_size=patch_size,
             hidden_dim=hidden_dim,
             num_channels=num_channels,
+            use_class_token=use_class_token,
+            use_patch_bias=use_patch_bias,
             data_format=data_format,
             dtype=dtype,
             name="vit_patching_and_embedding",
@@ -130,6 +146,8 @@ def __init__(
         self.layer_norm_epsilon = layer_norm_epsilon
         self.use_mha_bias = use_mha_bias
         self.use_mlp_bias = use_mlp_bias
+        self.use_class_token = use_class_token
+        self.use_patch_bias = use_patch_bias
         self.data_format = data_format
 
     def get_config(self):
@@ -147,6 +165,8 @@ def get_config(self):
                 "layer_norm_epsilon": self.layer_norm_epsilon,
                 "use_mha_bias": self.use_mha_bias,
                 "use_mlp_bias": self.use_mlp_bias,
+                "use_class_token": self.use_class_token,
+                "use_patch_bias": self.use_patch_bias,
             }
         )
         return config
```
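
A standalone sketch of the validation added above, restated outside the Keras class so the behavior is easy to see; `normalize_patch_size` is a hypothetical helper, not part of the library.

```python
# Hypothetical helper mirroring the checks added in __init__ above: an int
# patch size is expanded to a pair, and each image dimension must be
# divisible by the matching patch dimension.
def normalize_patch_size(image_shape, patch_size):
    height, width, _ = image_shape
    if isinstance(patch_size, int):
        patch_size = (patch_size, patch_size)
    if height % patch_size[0] != 0:
        raise ValueError(
            f"Input height {height} should be divisible by "
            f"patch size {patch_size[0]}."
        )
    if width % patch_size[1] != 0:
        raise ValueError(
            f"Input width {width} should be divisible by "
            f"patch size {patch_size[1]}."
        )
    return patch_size


print(normalize_patch_size((224, 448, 3), 16))  # (16, 16)
# normalize_patch_size((224, 450, 3), 16)       # raises: 450 % 16 != 0
```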

keras_hub/src/models/vit/vit_backbone_test.py

Lines changed: 10 additions & 2 deletions
```diff
@@ -9,7 +9,7 @@ class ViTBackboneTest(TestCase):
     def setUp(self):
         self.init_kwargs = {
             "image_shape": (28, 28, 3),
-            "patch_size": 4,
+            "patch_size": (4, 4),
             "num_layers": 3,
             "hidden_dim": 48,
             "num_heads": 6,
@@ -25,7 +25,15 @@ def test_backbone_basics(self):
             init_kwargs={**self.init_kwargs},
             input_data=self.input_data,
             expected_output_shape=(2, 50, 48),
-            run_quantization_check=False,
+        )
+
+    def test_backbone_basics_without_class_token(self):
+        self.init_kwargs["use_class_token"] = False
+        self.run_backbone_test(
+            cls=ViTBackbone,
+            init_kwargs={**self.init_kwargs},
+            input_data=self.input_data,
+            expected_output_shape=(2, 49, 48),
         )
 
     @pytest.mark.large
```
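
Where the expected output shapes come from: with a 28×28 input and 4×4 patches there are 49 patch positions, and the class token, when enabled, adds one more. A quick check:

```python
# 28x28 input, 4x4 patches: 7 * 7 = 49 patch positions.
num_patches = (28 // 4) * (28 // 4)
print(num_patches)      # 49 -> expected_output_shape (2, 49, 48) without class token
print(num_patches + 1)  # 50 -> expected_output_shape (2, 50, 48) with class token
```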

keras_hub/src/models/vit/vit_image_classifier_test.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -16,7 +16,7 @@ def setUp(self):
         self.labels = [0, 1]
         self.backbone = ViTBackbone(
             image_shape=(28, 28, 3),
-            patch_size=4,
+            patch_size=(4, 4),
             num_layers=3,
             num_heads=6,
             hidden_dim=48,
```
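
Per the commit note "patch size can also be int dtype", an int is still accepted and expanded by the backbone, so the two constructions below should build the same 7×7 patch grid; `mlp_dim=96` is an illustrative value, not taken from this test file.

```python
from keras_hub.src.models.vit.vit_backbone import ViTBackbone

# Both forms are accepted after this commit; the backbone expands an int
# into (patch_size, patch_size) before building the patch grid.
common = dict(
    image_shape=(28, 28, 3),
    num_layers=3,
    num_heads=6,
    hidden_dim=48,
    mlp_dim=96,  # illustrative value
)
vit_from_int = ViTBackbone(patch_size=4, **common)
vit_from_pair = ViTBackbone(patch_size=(4, 4), **common)
```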
keras_hub/src/models/vit/vit_image_converter.py

Lines changed: 0 additions & 70 deletions
```diff
@@ -1,78 +1,8 @@
 from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
 from keras_hub.src.models.vit.vit_backbone import ViTBackbone
-from keras_hub.src.utils.tensor_utils import preprocessing_function
 
 
 @keras_hub_export("keras_hub.layers.ViTImageConverter")
 class ViTImageConverter(ImageConverter):
-    """Converts images to the format expected by a ViT model.
-
-    This layer performs image normalization using mean and standard deviation
-    values. By default, it uses the same normalization as the
-    "google/vit-large-patch16-224" model on Hugging Face:
-    `norm_mean=[0.5, 0.5, 0.5]` and `norm_std=[0.5, 0.5, 0.5]`
-    ([reference](https://huggingface.co/google/vit-large-patch16-224/blob/main/preprocessor_config.json)).
-    These defaults are suitable for models pretrained using this normalization.
-
-    Args:
-        norm_mean: list or tuple of floats. Mean values for image normalization.
-            Defaults to `[0.5, 0.5, 0.5]`.
-        norm_std: list or tuple of floats. Standard deviation values for
-            image normalization. Defaults to `[0.5, 0.5, 0.5]`.
-        **kwargs: Additional keyword arguments passed to
-            `keras_hub.layers.preprocessing.ImageConverter`.
-
-    Examples:
-    ```python
-    import keras
-    import numpy as np
-    from keras_hub.src.layers import ViTImageConverter
-
-    # Example image (replace with your actual image data)
-    image = np.random.rand(1, 224, 224, 3) # Example: (B, H, W, C)
-
-    # Create a ViTImageConverter instance
-    converter = ViTImageConverter(
-        image_size=(28,28),
-        scale=1/255.
-    )
-    # Preprocess the image
-    preprocessed_image = converter(image)
-    ```
-    """
-
     backbone_cls = ViTBackbone
-
-    def __init__(
-        self, norm_mean=[0.5, 0.5, 0.5], norm_std=[0.5, 0.5, 0.5], **kwargs
-    ):
-        super().__init__(**kwargs)
-        self.norm_mean = norm_mean
-        self.norm_std = norm_std
-
-    @preprocessing_function
-    def call(self, inputs):
-        # TODO: Remove this whole function. Why can just use scale and offset
-        # in the base class.
-        x = super().call(inputs)
-        if self.norm_mean:
-            norm_mean = self._expand_non_channel_dims(self.norm_mean, x)
-            x, norm_mean = self._convert_types(x, norm_mean, self.compute_dtype)
-            x = x - norm_mean
-        if self.norm_std:
-            norm_std = self._expand_non_channel_dims(self.norm_std, x)
-            x, norm_std = self._convert_types(x, norm_std, x.dtype)
-            x = x / norm_std
-
-        return x
-
-    def get_config(self):
-        config = super().get_config()
-        config.update(
-            {
-                "norm_mean": self.norm_mean,
-                "norm_std": self.norm_std,
-            }
-        )
-        return config
```
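
The deleted `call()` only re-implemented mean/std normalization that the base `ImageConverter`'s `scale` and `offset` can express, as the removed TODO notes. A sketch of the arithmetic equivalence, using the old defaults `norm_mean = norm_std = [0.5, 0.5, 0.5]` as illustrative values; this also explains the commit note "use std for scale looping not mean":

```python
# (x / 255 - mean) / std  ==  x * (1 / (255 * std)) + (-mean / std)
# so looping over std (not mean) gives the per-channel scale, and
# -mean / std gives the per-channel offset.
norm_mean = [0.5, 0.5, 0.5]
norm_std = [0.5, 0.5, 0.5]
scale = [1.0 / (255.0 * s) for s in norm_std]           # ~0.00784 per channel
offset = [-m / s for m, s in zip(norm_mean, norm_std)]  # -1.0 per channel
print(scale, offset)
```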

keras_hub/src/models/vit/vit_layers.py

Lines changed: 33 additions & 18 deletions
```diff
@@ -75,12 +75,13 @@ class ViTPatchingAndEmbedding(keras.layers.Layer):
     """Patches the image and embeds the patches.
 
     Args:
-        image_size: int. Size of the input image (height or width).
-            Assumed to be square.
-        patch_size: int. Size of each image patch.
+        image_size: (int, int). Size of the input image.
+        patch_size: (int, int). Size of each image patch.
         hidden_dim: int. Dimensionality of the patch embeddings.
         num_channels: int. Number of channels in the input image. Defaults to
             `3`.
+        use_class_token: bool. Whether to use class token to be part of
+            patch embedding. Defaults to `True`.
         data_format: str. `"channels_last"` or `"channels_first"`. Defaults to
             `None` (which uses `"channels_last"`).
         **kwargs: Additional keyword arguments passed to `keras.layers.Layer`
@@ -92,12 +93,15 @@ def __init__(
         patch_size,
         hidden_dim,
         num_channels=3,
+        use_class_token=True,
+        use_patch_bias=True,
         data_format=None,
         **kwargs,
     ):
         super().__init__(**kwargs)
-        num_patches = (image_size // patch_size) ** 2
-        num_positions = num_patches + 1
+        grid_size = tuple([s // p for s, p in zip(image_size, patch_size)])
+        num_patches = grid_size[0] * grid_size[1]
+        num_positions = num_patches + 1 if use_class_token else num_patches
 
         # === Config ===
         self.image_size = image_size
@@ -106,19 +110,22 @@ def __init__(
         self.num_channels = num_channels
         self.num_patches = num_patches
         self.num_positions = num_positions
+        self.use_class_token = use_class_token
+        self.use_patch_bias = use_patch_bias
         self.data_format = standardize_data_format(data_format)
 
     def build(self, input_shape):
-        self.class_token = self.add_weight(
-            shape=(
-                1,
-                1,
-                self.hidden_dim,
-            ),
-            initializer="random_normal",
-            dtype=self.variable_dtype,
-            name="class_token",
-        )
+        if self.use_class_token:
+            self.class_token = self.add_weight(
+                shape=(
+                    1,
+                    1,
+                    self.hidden_dim,
+                ),
+                initializer="random_normal",
+                dtype=self.variable_dtype,
+                name="class_token",
+            )
         self.patch_embedding = keras.layers.Conv2D(
             filters=self.hidden_dim,
             kernel_size=self.patch_size,
@@ -127,6 +134,7 @@ def build(self, input_shape):
             activation=None,
             dtype=self.dtype_policy,
             data_format=self.data_format,
+            use_bias=self.use_patch_bias,
             name="patch_embedding",
         )
         self.patch_embedding.build(input_shape)
@@ -153,10 +161,16 @@ def call(self, inputs):
         patch_embeddings = ops.reshape(
             patch_embeddings, [embeddings_shape[0], -1, embeddings_shape[-1]]
         )
-        class_token = ops.tile(self.class_token, (embeddings_shape[0], 1, 1))
         position_embeddings = self.position_embedding(self.position_ids)
-        embeddings = ops.concatenate([class_token, patch_embeddings], axis=1)
-        return ops.add(embeddings, position_embeddings)
+
+        if self.use_class_token:
+            class_token = ops.tile(
+                self.class_token, (embeddings_shape[0], 1, 1)
+            )
+            patch_embeddings = ops.concatenate(
+                [class_token, patch_embeddings], axis=1
+            )
+        return ops.add(patch_embeddings, position_embeddings)
 
     def compute_output_shape(self, input_shape):
         return (
@@ -175,6 +189,7 @@ def get_config(self):
                 "num_channels": self.num_channels,
                 "num_patches": self.num_patches,
                 "num_positions": self.num_positions,
+                "use_class_token": self.use_class_token,
             }
         )
         return config
```
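
The position count the layer now computes, restated as a tiny standalone function; `count_positions` is hypothetical, for illustration only.

```python
# Mirrors the grid_size / num_positions logic added to __init__ above.
def count_positions(image_size, patch_size, use_class_token=True):
    grid_size = tuple(s // p for s, p in zip(image_size, patch_size))
    num_patches = grid_size[0] * grid_size[1]
    return num_patches + 1 if use_class_token else num_patches


print(count_positions((224, 448), (16, 16)))     # 14 * 28 + 1 = 393
print(count_positions((28, 28), (4, 4), False))  # 7 * 7 = 49
```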
