@@ -449,12 +449,12 @@ def forward(
             return self.dropout(t)
 
         if self.dropout_type == 'row':
-            batch, row, _, _ = t.shape
-            ones_shape = (batch, row, 1, 1)
+            batch, _, col, dim = t.shape
+            ones_shape = (batch, 1, col, dim)
 
         elif self.dropout_type == 'col':
-            batch, _, col, _ = t.shape
-            ones_shape = (batch, 1, col, 1)
+            batch, row, _, dim = t.shape
+            ones_shape = (batch, row, 1, dim)
 
         ones = t.new_ones(ones_shape)
         dropped = self.dropout(ones)
@@ -624,9 +624,8 @@ def forward(
         out = self.to_out_norm(out)
 
         out_gate = self.out_gate(x).sigmoid()
-        out = out * out_gate
 
-        return self.to_out(out)
+        return self.to_out(out) * out_gate
 
 # there are two types of attention in this paper, triangle and attention-pair-bias
 # they differ by how the attention bias is computed
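This hunk moves the sigmoid gate so it scales the projected output rather than the tensor fed into the projection. A minimal sketch of that gated output head, assuming square projections and module names chosen only for illustration:

import torch
from torch import nn

class GatedOutputHead(nn.Module):
    # illustrative stand-in, not the repository's module
    def __init__(self, dim):
        super().__init__()
        self.out_gate = nn.Linear(dim, dim, bias = False)
        self.to_out = nn.Linear(dim, dim, bias = False)

    def forward(self, x, out):
        out_gate = self.out_gate(x).sigmoid()
        # gate is applied after the final projection, matching the change above
        return self.to_out(out) * out_gate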
@@ -1316,8 +1315,8 @@ def __init__(
         # final projection of mean pooled repr -> out
 
         self.to_out = nn.Sequential(
-            LinearNoBias(dim, dim_pairwise),
-            nn.ReLU()
+            nn.ReLU(),
+            LinearNoBias(dim, dim_pairwise)
         )
 
         self.layerscale = nn.Parameter(torch.zeros(dim_pairwise)) if layerscale_output else 1.
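After the reorder, the projection applies the nonlinearity first and LinearNoBias last, so the head emits an unconstrained linear readout in dim_pairwise that is then scaled by the zero-initialised layerscale parameter. A small usage sketch under assumed shapes; the pooled-input name is hypothetical and LinearNoBias is stood in by nn.Linear with bias = False:

import torch
from torch import nn

dim, dim_pairwise = 64, 128

to_out = nn.Sequential(
    nn.ReLU(),                                    # nonlinearity first
    nn.Linear(dim, dim_pairwise, bias = False)    # then project to the pairwise dimension
)

layerscale = nn.Parameter(torch.zeros(dim_pairwise))   # zero init, so the branch starts as a no-op

pooled = torch.randn(2, dim)            # assumed mean-pooled representation
update = to_out(pooled) * layerscale    # contribution is learned during training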