@@ -433,18 +433,14 @@ def __init__(
         dim_hidden = default(dim_hidden, dim)
         self.norm = nn.LayerNorm(dim)

-        self.left_proj = Linear(dim, dim_hidden)
-        self.right_proj = Linear(dim, dim_hidden)
-
-        self.left_gate = Linear(dim, dim_hidden)
-        self.right_gate = Linear(dim, dim_hidden)
-        self.out_gate = Linear(dim, dim_hidden)
+        self.left_right_proj = nn.Sequential(
+            LinearNoBias(dim, dim_hidden * 4),
+            nn.GLU(dim = -1)
+        )

-        # initialize all gating to be identity
+        self.left_right_gate = LinearNoBias(dim, dim_hidden * 2)

-        for gate in (self.left_gate, self.right_gate, self.out_gate):
-            nn.init.constant_(gate.weight, 0.)
-            nn.init.constant_(gate.bias, 1.)
+        self.out_gate = LinearNoBias(dim, dim_hidden)

         if mix == 'outgoing':
             self.mix_einsum_eq = '... i k d, ... j k d -> ... i j d'
@@ -454,7 +450,7 @@ def __init__(
         self.to_out_norm = nn.LayerNorm(dim_hidden)

         self.to_out = Sequential(
-            Linear(dim_hidden, dim),
+            LinearNoBias(dim_hidden, dim),
             Dropout(dropout, dropout_type = dropout_type)
         )

@@ -470,24 +466,19 @@ def forward(

         x = self.norm(x)

-        left = self.left_proj(x)
-        right = self.right_proj(x)
+        left, right = self.left_right_proj(x).chunk(2, dim = -1)

         if exists(mask):
             left = left * mask
             right = right * mask

-        left_gate = self.left_gate(x).sigmoid()
-        right_gate = self.right_gate(x).sigmoid()
-        out_gate = self.out_gate(x).sigmoid()
-
-        left = left * left_gate
-        right = right * right_gate
-
         out = einsum(left, right, self.mix_einsum_eq)

         out = self.to_out_norm(out)
+
+        out_gate = self.out_gate(x).sigmoid()
         out = out * out_gate
+
         return self.to_out(out)

 # there are two types of attention in this paper, triangle and attention-pair-bias
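Note on the refactor above: nn.GLU splits its input in half along the given dimension and returns first_half * sigmoid(second_half), so the single LinearNoBias(dim, dim_hidden * 4) followed by nn.GLU(dim = -1) folds the separate left/right projections and their sigmoid gates into one matmul (the fused version also drops the biases and the identity gate initialization of the old layers). A minimal sketch of that fusion, using a plain nn.Linear(bias = False) as a stand-in for the repo's LinearNoBias and assumed toy shapes:

import torch
from torch import nn

dim, dim_hidden = 16, 32
x = torch.randn(2, 8, 8, dim)                         # (batch, i, j, pairwise dim) - assumed shape

proj = nn.Linear(dim, dim_hidden * 4, bias = False)   # stand-in for LinearNoBias(dim, dim_hidden * 4)

# fused path: one projection, GLU gating, then split into left / right
fused = nn.GLU(dim = -1)(proj(x))                     # (..., dim_hidden * 2)
left, right = fused.chunk(2, dim = -1)                # each (..., dim_hidden)

# equivalent unfused path with the same weights: GLU gates the first half of the
# projection with the sigmoid of the second half
w_value, w_gate = proj.weight.chunk(2, dim = 0)
value = x @ w_value.t()
gate = (x @ w_gate.t()).sigmoid()
assert torch.allclose(fused, value * gate, atol = 1e-6)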