@@ -75,12 +75,15 @@ class ViTPatchingAndEmbedding(keras.layers.Layer):
     """Patches the image and embeds the patches.
 
     Args:
-        image_size: int. Size of the input image (height or width).
-            Assumed to be square.
-        patch_size: int. Size of each image patch.
+        image_size: (int, int). Size of the input image (height, width).
+        patch_size: (int, int). Size of each image patch (height, width).
         hidden_dim: int. Dimensionality of the patch embeddings.
         num_channels: int. Number of channels in the input image. Defaults to
             `3`.
+        use_class_token: bool. Whether to prepend a learnable class token to
+            the patch embeddings. Defaults to `True`.
+        use_patch_bias: bool. Whether to use a bias in the patch embedding
+            `Conv2D` layer. Defaults to `True`.
         data_format: str. `"channels_last"` or `"channels_first"`. Defaults to
             `None` (which uses `"channels_last"`).
         **kwargs: Additional keyword arguments passed to `keras.layers.Layer`
@@ -92,12 +95,15 @@ def __init__(
         patch_size,
         hidden_dim,
         num_channels=3,
+        use_class_token=True,
+        use_patch_bias=True,
         data_format=None,
         **kwargs,
     ):
         super().__init__(**kwargs)
-        num_patches = (image_size // patch_size) ** 2
-        num_positions = num_patches + 1
+        grid_size = tuple([s // p for s, p in zip(image_size, patch_size)])
+        num_patches = grid_size[0] * grid_size[1]
+        num_positions = num_patches + 1 if use_class_token else num_patches
 
         # === Config ===
         self.image_size = image_size
@@ -106,19 +112,22 @@ def __init__(
         self.num_channels = num_channels
         self.num_patches = num_patches
         self.num_positions = num_positions
+        self.use_class_token = use_class_token
+        self.use_patch_bias = use_patch_bias
         self.data_format = standardize_data_format(data_format)
 
     def build(self, input_shape):
-        self.class_token = self.add_weight(
-            shape=(
-                1,
-                1,
-                self.hidden_dim,
-            ),
-            initializer="random_normal",
-            dtype=self.variable_dtype,
-            name="class_token",
-        )
+        if self.use_class_token:
+            self.class_token = self.add_weight(
+                shape=(
+                    1,
+                    1,
+                    self.hidden_dim,
+                ),
+                initializer="random_normal",
+                dtype=self.variable_dtype,
+                name="class_token",
+            )
         self.patch_embedding = keras.layers.Conv2D(
             filters=self.hidden_dim,
             kernel_size=self.patch_size,
@@ -127,6 +136,7 @@ def build(self, input_shape):
             activation=None,
             dtype=self.dtype_policy,
             data_format=self.data_format,
+            use_bias=self.use_patch_bias,
             name="patch_embedding",
         )
         self.patch_embedding.build(input_shape)
@@ -153,10 +163,16 @@ def call(self, inputs):
         patch_embeddings = ops.reshape(
             patch_embeddings, [embeddings_shape[0], -1, embeddings_shape[-1]]
         )
-        class_token = ops.tile(self.class_token, (embeddings_shape[0], 1, 1))
         position_embeddings = self.position_embedding(self.position_ids)
-        embeddings = ops.concatenate([class_token, patch_embeddings], axis=1)
-        return ops.add(embeddings, position_embeddings)
+
+        if self.use_class_token:
+            class_token = ops.tile(
+                self.class_token, (embeddings_shape[0], 1, 1)
+            )
+            patch_embeddings = ops.concatenate(
+                [class_token, patch_embeddings], axis=1
+            )
+        return ops.add(patch_embeddings, position_embeddings)
 
     def compute_output_shape(self, input_shape):
         return (
@@ -175,6 +191,8 @@ def get_config(self):
                 "num_channels": self.num_channels,
                 "num_patches": self.num_patches,
                 "num_positions": self.num_positions,
+                "use_class_token": self.use_class_token,
+                "use_patch_bias": self.use_patch_bias,
             }
         )
         return config
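
For reviewers, a minimal usage sketch of the patched layer (the import path and all shapes below are illustrative assumptions, not part of this diff): with image_size=(224, 224) and patch_size=(16, 16), grid_size is (14, 14), so num_patches = 196 and num_positions = 197 with the class token, or 196 without it.

# Sketch only: the module path is assumed; adjust to the actual checkout.
import numpy as np
from keras_hub.src.models.vit.vit_layers import ViTPatchingAndEmbedding

layer = ViTPatchingAndEmbedding(
    image_size=(224, 224),  # now (height, width), no longer a square int
    patch_size=(16, 16),    # grid_size = (224 // 16, 224 // 16) = (14, 14)
    hidden_dim=768,
    use_class_token=True,   # 14 * 14 patches + 1 class token = 197 positions
    use_patch_bias=True,    # keep the bias in the patch-embedding Conv2D
)
images = np.zeros((2, 224, 224, 3), dtype="float32")  # channels_last batch
embeddings = layer(images)
print(embeddings.shape)  # (2, 197, 768); (2, 196, 768) if use_class_token=False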