@@ -13,22 +13,13 @@ def exists(val):
 def default(val, d):
     return val if exists(val) else d
 
-# pre-layernorm
-
-class PreNorm(nn.Module):
-    def __init__(self, dim, fn):
-        super().__init__()
-        self.norm = nn.LayerNorm(dim)
-        self.fn = fn
-    def forward(self, x, **kwargs):
-        return self.fn(self.norm(x), **kwargs)
-
 # feedforward
 
 class FeedForward(nn.Module):
     def __init__(self, dim, hidden_dim, dropout = 0.):
         super().__init__()
         self.net = nn.Sequential(
+            nn.LayerNorm(dim),
             nn.Linear(dim, hidden_dim),
             nn.GELU(),
             nn.Dropout(dropout),
@@ -47,6 +38,7 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
         self.heads = heads
         self.scale = dim_head ** -0.5
 
+        self.norm = nn.LayerNorm(dim)
         self.attend = nn.Softmax(dim = -1)
         self.dropout = nn.Dropout(dropout)
 
@@ -60,6 +52,7 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
 
     def forward(self, x, context = None, kv_include_self = False):
         b, n, _, h = *x.shape, self.heads
+        x = self.norm(x)
         context = default(context, x)
 
         if kv_include_self:
@@ -86,8 +79,8 @@ def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
         self.norm = nn.LayerNorm(dim)
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
-                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
+                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
+                FeedForward(dim, mlp_dim, dropout = dropout)
             ]))
 
     def forward(self, x):
@@ -121,8 +114,8 @@ def __init__(self, sm_dim, lg_dim, depth, heads, dim_head, dropout):
         self.layers = nn.ModuleList([])
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                ProjectInOut(sm_dim, lg_dim, PreNorm(lg_dim, Attention(lg_dim, heads = heads, dim_head = dim_head, dropout = dropout))),
-                ProjectInOut(lg_dim, sm_dim, PreNorm(sm_dim, Attention(sm_dim, heads = heads, dim_head = dim_head, dropout = dropout)))
+                ProjectInOut(sm_dim, lg_dim, Attention(lg_dim, heads = heads, dim_head = dim_head, dropout = dropout)),
+                ProjectInOut(lg_dim, sm_dim, Attention(sm_dim, heads = heads, dim_head = dim_head, dropout = dropout))
             ]))
 
     def forward(self, sm_tokens, lg_tokens):
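For readers skimming the hunks: the change removes the standalone PreNorm wrapper and moves the pre-layernorm inside each sub-module (FeedForward gains an nn.LayerNorm at the start of its Sequential, Attention applies self.norm at the top of forward), so the outer blocks only have to add residuals. The sketch below is a minimal, self-contained illustration of that pattern, not the repository code: it keeps only the FeedForward branch, and the residual wiring in Transformer.forward is an assumption, since that part of the file does not appear in these hunks.

import torch
from torch import nn

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),            # pre-norm now lives inside the sub-module
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

class Transformer(nn.Module):
    def __init__(self, dim, depth, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([FeedForward(dim, mlp_dim, dropout = dropout) for _ in range(depth)])
        self.norm = nn.LayerNorm(dim)     # final norm, as in the Transformer hunk above

    def forward(self, x):
        for ff in self.layers:
            x = ff(x) + x                 # residual only; no PreNorm wrapper needed (assumed wiring)
        return self.norm(x)

x = torch.randn(1, 16, 64)
print(Transformer(dim = 64, depth = 2, mlp_dim = 128)(x).shape)   # torch.Size([1, 16, 64])

Folding the norm into each sub-module keeps Attention and FeedForward self-contained, which is what lets CrossTransformer in the last hunk drop the per-dimension PreNorm(sm_dim, ...) / PreNorm(lg_dim, ...) wrappers around the ProjectInOut branches.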