@@ -185,6 +185,7 @@ def __init__(
185185 window_size = None ,
186186 num_memory_kv : int = 0 ,
187187 laser = False ,
188+ laser_softclamp_value = 15. ,
188189 enable_attn_softclamp = False ,
189190 attn_softclamp_value = 50. ,
190191 softmax_full_precision = False
@@ -206,10 +207,11 @@ def __init__(
206207 dim_inner = dim_head * heads
207208
208209 self .attend = Attend (
209- laser = laser ,
210210 dropout = dropout ,
211211 window_size = window_size ,
212212 enable_attn_softclamp = enable_attn_softclamp ,
213+ laser = laser ,
214+ laser_softclamp_value = laser_softclamp_value ,
213215 attn_softclamp_value = attn_softclamp_value ,
214216 softmax_full_precision = softmax_full_precision
215217 )
@@ -305,6 +307,7 @@ def __init__(
305307 self ,
306308 dropout = 0. ,
307309 laser = False ,
310+ laser_softclamp_value = 15. ,
308311 window_size = None ,
309312 scale : float | None = None ,
310313 enable_attn_softclamp = False ,
@@ -336,6 +339,7 @@ def __init__(
336339 # laser attention
337340
338341 self .laser = laser
342+ self .laser_softclamp_value = laser_softclamp_value
339343
340344 # softclamp attention logits
341345 # being adopted by a number of recent llms (gemma, grok)
@@ -460,8 +464,7 @@ def local_attn(
460464 # maybe laser
461465
462466 if self .laser :
463- v_max = v .amax (dim = - 2 , keepdim = True )
464- v = (v - v_max ).exp ()
467+ v = softclamp (v , self .laser_softclamp_value )
465468
466469 # aggregate
467470
@@ -470,7 +473,7 @@ def local_attn(
470473 # maybe laser
471474
472475 if self .laser :
473- out = log (out ) + v_max
476+ out = log (out )
474477
475478 # un-window the output
476479
0 commit comments