@@ -81,6 +81,15 @@ def __init__(self, dim):
     def forward(self, x):
         return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
 
+class MultiHeadedRMSNorm(nn.Module):
+    def __init__(self, dim, heads = 1):
+        super().__init__()
+        self.scale = dim ** 0.5
+        self.gamma = nn.Parameter(torch.ones(heads, 1, dim))
+
+    def forward(self, x):
+        return F.normalize(x, dim = -1) * self.scale * self.gamma
+
 # positional embeds
 
 class LearnedSinusoidalPosEmb(nn.Module):
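Note: the added MultiHeadedRMSNorm L2-normalizes the last dimension, rescales by sqrt(dim), and applies a learned per-head gain. A minimal sketch of its effect on a (batch, heads, seq, dim_head) tensor, with made-up shapes (illustrative only, not part of the diff):

import torch
import torch.nn.functional as F

b, h, n, d = 2, 4, 16, 32                    # hypothetical shapes
q = torch.randn(b, h, n, d)
gamma = torch.ones(h, 1, d)                  # learned per-head gain, ones at init
q_normed = F.normalize(q, dim = -1) * (d ** 0.5) * gamma   # unit norm per head, rescaled and gated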
@@ -104,6 +113,7 @@ def __init__(
         heads = 4,
         dim_head = 32,
         norm = False,
+        qk_norm = False,
         time_cond_dim = None
     ):
         super().__init__()
@@ -127,6 +137,11 @@ def __init__(
 
         self.to_qkv = nn.Linear(dim, hidden_dim * 3, bias = False)
 
+        self.qk_norm = qk_norm
+        if qk_norm:
+            self.q_norm = MultiHeadedRMSNorm(dim_head, heads)
+            self.k_norm = MultiHeadedRMSNorm(dim_head, heads)
+
         self.to_out = nn.Sequential(
             nn.Linear(hidden_dim, dim, bias = False),
             LayerNorm(dim)
@@ -148,6 +163,10 @@ def forward(
         qkv = self.to_qkv(x).chunk(3, dim = -1)
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
 
+        if self.qk_norm:
+            q = self.q_norm(q)
+            k = self.k_norm(k)
+
         q = q.softmax(dim = -1)
         k = k.softmax(dim = -2)
 
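Note: in LinearAttention the normalization hits q and k right after the heads are split out and before the q/k softmaxes. A hedged usage sketch, assuming the classes in this file are in scope and with an illustrative dim (forward may additionally take a time-conditioning input when time_cond_dim is set):

attn = LinearAttention(dim = 256, qk_norm = True)   # heads = 4, dim_head = 32 come from the defaults above
out = attn(torch.randn(1, 1024, 256))               # assumes a plain x-only forward when no time conditioning is used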
@@ -169,7 +188,8 @@ def __init__(
         norm = False,
         norm_context = False,
         time_cond_dim = None,
-        flash = False
+        flash = False,
+        qk_norm = False
     ):
         super().__init__()
         hidden_dim = dim_head * heads
@@ -197,6 +217,11 @@ def __init__(
         self.to_kv = nn.Linear(dim_context, hidden_dim * 2, bias = False)
         self.to_out = nn.Linear(hidden_dim, dim, bias = False)
 
+        self.qk_norm = qk_norm
+        if qk_norm:
+            self.q_norm = MultiHeadedRMSNorm(dim_head, heads)
+            self.k_norm = MultiHeadedRMSNorm(dim_head, heads)
+
         self.attend = Attend(flash = flash)
 
     def forward(
@@ -222,6 +247,10 @@ def forward(
         qkv = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
 
+        if self.qk_norm:
+            q = self.q_norm(q)
+            k = self.k_norm(k)
+
         out = self.attend(q, k, v)
 
         out = rearrange(out, 'b h n d -> b n (h d)')
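Note: the (cross-)attention class gets the same treatment, normalizing q and k per head before handing them to Attend. Constructing it with the new flag might look like the following sketch (argument values are illustrative; the remaining arguments keep the defaults from the signature above):

cross_attn = Attention(dim = 256, dim_head = 32, heads = 4, qk_norm = True, flash = False)
# out = cross_attn(x, context)  -- assumes forward takes the queried tokens plus a context sequence, per the to_kv(context) call above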