@@ -178,7 +178,8 @@ def __init__(
         num_memory_kv: int = 0,
         enable_attn_softclamp = False,
         attn_softclamp_value = 50.,
-        init_gate_bias = -2.
+        init_gate_bias = -2.,
+        softmax_full_precision = False
     ):
         super().__init__()
         """
@@ -201,6 +202,7 @@ def __init__(
             window_size = window_size,
             enable_attn_softclamp = enable_attn_softclamp,
             attn_softclamp_value = attn_softclamp_value,
+            softmax_full_precision = softmax_full_precision
         )

         self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
@@ -279,7 +281,8 @@ def __init__(
         window_size = None,
         scale: float | None = None,
         enable_attn_softclamp = False,
-        attn_softclamp_value = 50.
+        attn_softclamp_value = 50.,
+        softmax_full_precision = False
     ):
         super().__init__()
         """
@@ -309,6 +312,9 @@ def __init__(
         self.enable_attn_softclamp = enable_attn_softclamp
         self.attn_softclamp_value = attn_softclamp_value

+        # whether to use full precision for softmax
+        self.softmax_full_precision = softmax_full_precision
+
     @typecheck
     def local_attn(
         self,
@@ -505,9 +511,16 @@ def forward(
                 mask, sim, max_neg_value(sim)
             )

+        # attention cast float32 - in case there are instabilities with float16
+
+        softmax_kwargs = dict()
+
+        if self.softmax_full_precision:
+            softmax_kwargs.update(dtype = torch.float32)
+
         # attention

-        attn = sim.softmax(dim = -1, dtype = torch.float32)
+        attn = sim.softmax(dim = -1, **softmax_kwargs)
         attn = attn.to(dtype)

         attn = self.attn_dropout(attn)
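
Below is a minimal standalone sketch of the pattern the last hunk introduces, for readers skimming the diff. The sim / dtype / softmax_kwargs names mirror the commit; the tensor shape and the plain boolean flag are illustrative assumptions (in the commit the flag lives on the module as self.softmax_full_precision).

import torch

softmax_full_precision = True  # the new flag; False leaves softmax in the logits' own dtype

sim = torch.randn(1, 4, 128, 128, dtype = torch.float16)  # attention logits in half precision
dtype = sim.dtype

# only request float32 accumulation when the flag is set
softmax_kwargs = dict()

if softmax_full_precision:
    softmax_kwargs.update(dtype = torch.float32)

attn = sim.softmax(dim = -1, **softmax_kwargs)  # softmax computed in float32 when enabled
attn = attn.to(dtype)                           # cast back so downstream matmuls stay in the input dtype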