@@ -99,18 +99,17 @@ def pad_to_multiple(
     return pad_at_dim(t, (0, padding_needed), dim = dim, value = value)
 
 @typecheck
-def concat_neighboring_windows(
+def concat_previous_window(
     t: Tensor,
     *,
     dim_seq: int,
     dim_window: int
 ):
-    t = pad_at_dim(t, (1, 1), dim = dim_seq, value = 0.)
+    t = pad_at_dim(t, (1, 0), dim = dim_seq, value = 0.)
 
     t = torch.cat((
-        slice_at_dim(t, slice(None, -2), dim = dim_seq),
-        slice_at_dim(t, slice(1, -1), dim = dim_seq),
-        slice_at_dim(t, slice(2, None), dim = dim_seq)
+        slice_at_dim(t, slice(None, -1), dim = dim_seq),
+        slice_at_dim(t, slice(1, None), dim = dim_seq),
     ), dim = dim_window)
 
     return t
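Note (illustration only, not part of this commit): for a concrete windowed tensor of shape (nw, w, w, d) with dim_seq = 0 and dim_window = -2, the renamed helper now pads a single zero window in front of the window-sequence dim and pairs each window with its predecessor, so the window dim doubles to 2*w instead of tripling to 3*w as before.

import torch

nw, w, d = 4, 3, 2
t = torch.randn(nw, w, w, d)

# pad_at_dim(t, (1, 0), dim = dim_seq): one zero window in front of the window-sequence dim
t_padded = torch.cat((torch.zeros(1, w, w, d), t), dim = 0)

# previous window | current window, concatenated along dim_window
out = torch.cat((t_padded[:-1], t_padded[1:]), dim = -2)

assert out.shape == (nw, w, 2 * w, d)   # was (nw, w, 3 * w, d) before this change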
@@ -121,14 +120,14 @@ def concat_neighboring_windows(
 def full_pairwise_repr_to_windowed(
     pairwise_repr: Float['... m m dp'],
     window_size: int
-) -> Float['... n w (w*3) dp']:
+) -> Float['... n w (w*2) dp']:
 
     seq_len, device = pairwise_repr.shape[-2], pairwise_repr.device
 
     padding_needed = (window_size - (seq_len % window_size)) % window_size
     pairwise_repr = F.pad(pairwise_repr, (0, 0, 0, padding_needed, 0, padding_needed), value = 0.)
     pairwise_repr = rearrange(pairwise_repr, '... (i w1) (j w2) d -> ... i j w1 w2 d', w1 = window_size, w2 = window_size)
-    pairwise_repr = concat_neighboring_windows(pairwise_repr, dim_seq = -4, dim_window = -2)
+    pairwise_repr = concat_previous_window(pairwise_repr, dim_seq = -4, dim_window = -2)
 
     # get the diagonal
 
@@ -145,7 +144,7 @@ def full_pairwise_repr_to_windowed(
 def full_attn_bias_to_windowed(
     attn_bias: Float['... m m'],
     window_size: int
-) -> Float['... n w (w*3)']:
+) -> Float['... n w (w*2)']:
 
     attn_bias = rearrange(attn_bias, '... -> ... 1')
     attn_bias = full_pairwise_repr_to_windowed(attn_bias, window_size = window_size)
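Shape check (again an illustration, not code from this file): with the previous-window-only concat, full_pairwise_repr_to_windowed returns '... n w (w*2) dp' and full_attn_bias_to_windowed returns '... n w (w*2)'. The sketch below replays the same steps for an unbatched pairwise tensor, assuming the sequence length is already a multiple of the window size and taking the i == j diagonal with a plain loop.

import torch
from einops import rearrange

m, w, dp = 16, 4, 8        # sequence length, window size, pairwise dim
n = m // w                 # number of windows (no padding needed for this choice of m)

pairwise = torch.randn(m, m, dp)
x = rearrange(pairwise, '(i w1) (j w2) d -> i j w1 w2 d', w1 = w, w2 = w)

# concat_previous_window with dim_seq = -4 (the j windows) and dim_window = -2 (w2)
x = torch.cat((torch.zeros(n, 1, w, w, dp), x), dim = 1)
x = torch.cat((x[:, :-1], x[:, 1:]), dim = -2)

# keep only the diagonal windows (i == j)
windowed = torch.stack([x[i, i] for i in range(n)])

assert windowed.shape == (n, w, 2 * w, dp)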
@@ -215,7 +214,7 @@ def forward(
         seq: Float['b i d'],
         mask: Bool['b n'] | None = None,
         context: Float['b j d'] | None = None,
-        attn_bias: Float['... i j'] | Float['... nw w (w*3)'] | None = None
+        attn_bias: Float['... i j'] | Float['... nw w (w*2)'] | None = None
 
     ) -> Float['b i d']:
 
@@ -316,7 +315,7 @@ def local_attn(
         k: Float['b h n d'],
         v: Float['b h n d'],
         mask: Bool['b n'] | None = None,
-        attn_bias: Float['... n n'] | Float['... nw w (w*3)'] | None = None
+        attn_bias: Float['... n n'] | Float['... nw w (w*2)'] | None = None
     ) -> Float['b h n d']:
         """
         simple local attention with a radius of 1 window size
@@ -345,11 +344,11 @@ def local_attn(
         # just do radius of 1 for now
         # perhaps not even necessary, and could try shifted windows (a la Swin)
 
-        k, v = tuple(pad_at_dim(t, (1, 1), dim = -2) for t in (k, v))
-        mask = F.pad(mask, (1, 1), value = False)
+        k, v = tuple(pad_at_dim(t, (1, 0), dim = -2) for t in (k, v))
+        mask = F.pad(mask, (1, 0), value = False)
 
-        k, v = tuple(torch.cat((t[..., :-2, :], t[..., 1:-1, :], t[..., 2:, :]), dim = -2) for t in (k, v))
-        mask = torch.cat((mask[..., :-2], mask[..., 1:-1], mask[..., 2:]), dim = -1)
+        k, v = tuple(torch.cat((t[..., :-1, :], t[..., 1:, :]), dim = -2) for t in (k, v))
+        mask = torch.cat((mask[..., :-1], mask[..., 1:]), dim = -1)
 
         # handle attention bias (inefficiently)
 
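Net effect on the attention math (sketch with assumed shapes, since the windowing of q/k/v happens outside this hunk): each query window of size w now scores against 2*w keys, the previous window followed by its own, which is exactly the '... nw w (w*2)' layout the windowed attn_bias annotations were changed to.

import torch

b, h, nw, w, d = 1, 2, 4, 3, 8         # assumed: batch, heads, num windows, window size, head dim

q = torch.randn(b, h, nw, w, d)        # one window of queries at a time
k = torch.randn(b, h, nw, 2 * w, d)    # keys carrying previous window | current window
bias = torch.randn(nw, w, 2 * w)       # windowed attention bias, '... nw w (w*2)'

sim = torch.einsum('bhnid,bhnjd->bhnij', q, k) + bias

assert sim.shape == (b, h, nw, w, 2 * w)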
@@ -399,7 +398,7 @@ def forward(
         k: Float['b h j d'],
         v: Float['b h j d'],
         mask: Bool['b j'] | None = None,
-        attn_bias: Float['... i j'] | Float['... nw w (w*3)'] | None = None,
+        attn_bias: Float['... i j'] | Float['... nw w (w*2)'] | None = None,
     ) -> Float['b h i d']:
 
         is_windowed_attn_bias = None
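Finally, a small usage sketch (mine; the import path and sizes are assumptions): a caller holding a full '... m m' bias can pre-window it with the helper updated above, and after this commit the layout it should pass in is '... n w (w*2)' rather than '... n w (w*3)'.

import torch
from alphafold3_pytorch.attention import full_attn_bias_to_windowed   # assumed import path

m, w = 16, 4                           # example sizes only
attn_bias = torch.randn(m, m)

windowed_bias = full_attn_bias_to_windowed(attn_bias, window_size = w)
assert windowed_bias.shape == (m // w, w, 2 * w)    # '... n w (w*2)'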