@@ -119,9 +119,11 @@ def __init__(
         dim,
         dim_head = 32,
         dropout = 0.,
-        window_size = 7
+        window_size = 7,
+        num_registers = 1
     ):
         super().__init__()
+        assert num_registers > 0
         assert (dim % dim_head) == 0, 'dimension should be divisible by dimension per head'
 
         self.heads = dim // dim_head
@@ -142,7 +144,9 @@ def __init__(
 
         # relative positional bias
 
-        self.rel_pos_bias = nn.Embedding((2 * window_size - 1) ** 2, self.heads)
+        num_rel_pos_bias = (2 * window_size - 1) ** 2
+
+        self.rel_pos_bias = nn.Embedding(num_rel_pos_bias + 1, self.heads)
 
         pos = torch.arange(window_size)
         grid = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij'))
@@ -151,10 +155,11 @@ def __init__(
         rel_pos += window_size - 1
         rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(dim = -1)
 
+        rel_pos_indices = F.pad(rel_pos_indices, (num_registers, 0, num_registers, 0), value = num_rel_pos_bias)
         self.register_buffer('rel_pos_indices', rel_pos_indices, persistent = False)
 
     def forward(self, x):
-        device, h = x.device, self.heads
+        device, h, bias_indices = x.device, self.heads, self.rel_pos_indices
 
         x = self.norm(x)
 
@@ -176,13 +181,8 @@ def forward(self, x):
 
         # add positional bias
 
-        bias = self.rel_pos_bias(self.rel_pos_indices)
-        bias = rearrange(bias, 'i j h -> h i j')
-
-        num_registers = sim.shape[-1] - bias.shape[-1]
-        bias = F.pad(bias, (num_registers, 0, num_registers, 0), value = 0.)
-
-        sim = sim + bias
+        bias = self.rel_pos_bias(bias_indices)
+        sim = sim + rearrange(bias, 'i j h -> h i j')
 
         # attention
 
@@ -215,6 +215,7 @@ def __init__(
     ):
         super().__init__()
         assert isinstance(depth, tuple), 'depth needs to be tuple if integers indicating number of transformer blocks at that stage'
+        assert num_register_tokens > 0
 
         # convolutional stem
 
@@ -256,10 +257,10 @@ def __init__(
                     shrinkage_rate = mbconv_shrinkage_rate
                 )
 
-                block_attn = Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = window_size)
+                block_attn = Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = window_size, num_registers = num_register_tokens)
                 block_ff = FeedForward(dim = layer_dim, dropout = dropout)
 
-                grid_attn = Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = window_size)
+                grid_attn = Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = window_size, num_registers = num_register_tokens)
                 grid_ff = FeedForward(dim = layer_dim, dropout = dropout)
 
                 register_tokens = nn.Parameter(torch.randn(num_register_tokens, layer_dim))
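
For intuition, here is a minimal standalone sketch of what the refactor does: instead of zero-padding the bias matrix on every forward pass, the relative-position index table is padded once at init with the id of one extra embedding row, so every attention pair involving a register token shares a single learned bias. The toy values for `window_size`, `num_registers` and `heads` are my own, and a plain `permute` stands in for the repo's `rearrange`; this is not the repo's code, just an illustration of the indexing.

```python
import torch
import torch.nn.functional as F
from torch import nn

# toy sizes, not the repo's defaults
window_size   = 3
num_registers = 2
heads         = 4

num_rel_pos_bias = (2 * window_size - 1) ** 2              # 25 distinct relative offsets
rel_pos_bias = nn.Embedding(num_rel_pos_bias + 1, heads)   # +1 row shared by all register pairs

# same relative-position index construction as in the diff (einops-free)
pos = torch.arange(window_size)
grid = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij'))   # (2, w, w)
grid = grid.flatten(1).t()                                       # (w*w, 2) coordinates
rel_pos = grid[:, None, :] - grid[None, :, :]                    # (w*w, w*w, 2) offsets
rel_pos += window_size - 1
rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(dim = -1)

# pad once at init: every row/column belonging to a register token
# points at the extra (last) embedding slot
rel_pos_indices = F.pad(rel_pos_indices, (num_registers, 0, num_registers, 0), value = num_rel_pos_bias)

# forward-time lookup is now a plain embedding gather, no shape-dependent padding
bias = rel_pos_bias(rel_pos_indices)        # (n, n, heads)
bias = bias.permute(2, 0, 1)                # 'i j h -> h i j'

n = window_size ** 2 + num_registers
assert bias.shape == (heads, n, n)
print(rel_pos_indices[:3, :3])              # top-left rows/cols are the register positions
```

Compared with the removed forward-path code, the register positions now receive a learned bias (the extra embedding row) rather than a hard-coded zero, and the per-forward `F.pad` on the bias disappears.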