@@ -58,7 +58,7 @@ def window_unpartition(
5858 x: Unpartitioned tensor of shape (B, H, W, C).
5959 """
6060 hp , wp = pad_hw
61- h , w = hw
61+ h , w = hw [ 0 ], hw [ 1 ]
6262 window_size = tf .shape (windows )[1 ]
6363 nb_windows = (hp // window_size ) * (wp // window_size )
6464 n = tf .shape (windows )[0 ] // nb_windows
@@ -93,7 +93,7 @@ def get_rel_pos(
9393 Extracted positional embeddings according to relative positions.
9494 """
9595 m = tf .shape (rel_pos )[0 ]
96- max_rel_dist = int (2 * max (q_size , k_size ) - 1 )
96+ max_rel_dist = tf . cast (2 * tf . math . maximum (q_size , k_size ) - 1 , tf . int32 )
9797
9898 if interpolate_pos :
9999 # Interpolate positional embeddings if needed.
@@ -108,10 +108,10 @@ def get_rel_pos(
108108 q_coords = tf .expand_dims (tf .range (q_size , dtype = tf .float32 ), axis = - 1 )
109109 k_coords = tf .expand_dims (tf .range (k_size , dtype = tf .float32 ), axis = 0 )
110110 # Scale the coords of the shorter side by the size ratio when q and k have different sizes.
111- q_coords = q_coords * tf .cast (max (k_size / q_size , 1.0 ), tf .float32 )
112- k_coords = k_coords * tf .cast (max (q_size / k_size , 1.0 ), tf .float32 )
111+ q_coords = q_coords * tf .cast (tf . math . maximum (k_size / q_size , 1.0 ), tf .float32 )
112+ k_coords = k_coords * tf .cast (tf . math . maximum (q_size / k_size , 1.0 ), tf .float32 )
113113
114- lambda_ = tf .cast (max (q_size / k_size , 1.0 ), tf .float32 )
114+ lambda_ = tf .cast (tf . math . maximum (q_size / k_size , 1.0 ), tf .float32 )
115115 offset = tf .cast (k_size - 1 , tf .float32 ) * lambda_
116116 relative_coords = (q_coords - k_coords ) + offset
117117 relative_coords = tf .cast (relative_coords , tf .int32 )
@@ -168,7 +168,7 @@ def add_decomposed_rel_pos(
168168 return attn
169169
170170
171- class Attention (tf .keras .layers .Layer ):
171+ class RelPosAttention (tf .keras .layers .Layer ):
172172 """Multi-head Attention block with relative position embeddings."""
173173
174174 def __init__ (
@@ -263,7 +263,7 @@ def call(self, x, training=False):
263263 return x
264264
265265
266- class Block (tf .keras .layers .Layer ):
266+ class ImageEncoderBlock (tf .keras .layers .Layer ):
267267 """
268268 Transformer blocks with support for window attention and residual propagation.
269269 """
@@ -316,7 +316,7 @@ def __init__(
316316 norm_layer = norm_layer_factory (norm_layer )
317317
318318 self .norm1 = norm_layer (name = "norm1" )
319- self .attn = Attention (
319+ self .attn = RelPosAttention (
320320 fixed_input_size = self .fixed_input_size ,
321321 embed_dim = self .embed_dim ,
322322 nb_heads = self .nb_heads ,
@@ -438,7 +438,7 @@ def __init__(
438438 self .pos_embed = None
439439
440440 self .blocks = [
441- Block (
441+ ImageEncoderBlock (
442442 fixed_input_size = self .fixed_input_size ,
443443 embed_dim = self .embed_dim ,
444444 nb_heads = self .nb_heads ,
0 commit comments