Commit 0cee1cc

use softclamping to address numerical issues with laser instead
1 parent 34f6c59
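The `softclamp` helper the diff below calls is defined elsewhere in attention.py and does not appear in this commit. A minimal sketch of it, assuming it is the usual tanh-based soft clamp (the same shape of bound the file already applies to attention logits via `attn_softclamp_value`):

```python
import torch

def softclamp(t: torch.Tensor, value: float) -> torch.Tensor:
    # smoothly bounds t to the open interval (-value, value):
    # near-identity for |t| << value, saturating for large |t|
    return (t / value).tanh() * value
```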

alphafold3_pytorch/attention.py

Lines changed: 7 additions & 4 deletions
@@ -185,6 +185,7 @@ def __init__(
         window_size = None,
         num_memory_kv: int = 0,
         laser = False,
+        laser_softclamp_value = 15.,
         enable_attn_softclamp = False,
         attn_softclamp_value = 50.,
         softmax_full_precision = False
@@ -206,10 +207,11 @@ def __init__(
         dim_inner = dim_head * heads

         self.attend = Attend(
-            laser = laser,
             dropout = dropout,
             window_size = window_size,
             enable_attn_softclamp = enable_attn_softclamp,
+            laser = laser,
+            laser_softclamp_value = laser_softclamp_value,
             attn_softclamp_value = attn_softclamp_value,
             softmax_full_precision = softmax_full_precision
         )
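The two hunks above just thread the new knob from `Attention` down into `Attend`. A hypothetical construction showing the resulting API surface (the `dim` / `dim_head` / `heads` arguments are assumed from the surrounding class, not shown in this diff):

```python
from alphafold3_pytorch.attention import Attention

attn = Attention(
    dim = 384,
    dim_head = 64,
    heads = 8,
    laser = True,                 # laser attention was already a flag
    laser_softclamp_value = 15.   # the clamp bound this commit introduces
)
```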
@@ -305,6 +307,7 @@ def __init__(
         self,
         dropout = 0.,
         laser = False,
+        laser_softclamp_value = 15.,
         window_size = None,
         scale: float | None = None,
         enable_attn_softclamp = False,
@@ -336,6 +339,7 @@ def __init__(
         # laser attention

         self.laser = laser
+        self.laser_softclamp_value = laser_softclamp_value

         # softclamp attention logits
         # being adopted by a number of recent llms (gemma, grok)
@@ -460,8 +464,7 @@ def local_attn(
         # maybe laser

         if self.laser:
-            v_max = v.amax(dim = -2, keepdim = True)
-            v = (v - v_max).exp()
+            v = softclamp(v, self.laser_softclamp_value)

         # aggregate

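A quick numerical illustration of what this swap buys (hypothetical values, reusing the `softclamp` sketch above): exponentiating raw values can overflow to `inf`, while exponentiating softclamped values is bounded by `exp(±laser_softclamp_value)`.

```python
v = torch.tensor([-1000., -5., 0., 5., 1000.])

v.exp()                  # tensor([0.0000e+00, 6.7379e-03, 1.0000e+00, 1.4841e+02, inf])
softclamp(v, 15.).exp()  # all finite, bounded above by exp(15.) ~ 3.3e6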
@@ -470,7 +473,7 @@ def local_attn(
         # maybe laser

         if self.laser:
-            out = log(out) + v_max
+            out = log(out)

         # un-window the output

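Condensed, the before/after of the laser value path in `local_attn` looks like this (a sketch directly mirroring the hunks above, with the windowed einsum aggregation simplified to a plain matmul). The old version had to carry `v_max` across the two laser branches to undo the shift in log space; clamping removes that coupling.

```python
import torch

def laser_aggregate_before(attn, v):
    # old: shift values by their max so .exp() cannot overflow,
    # then undo the shift in log space after aggregating
    v_max = v.amax(dim = -2, keepdim = True)
    v = (v - v_max).exp()
    out = attn @ v
    return out.log() + v_max

def laser_aggregate_after(attn, v, laser_softclamp_value = 15.):
    # new: bound the values up front; no max bookkeeping needed
    v = softclamp(v, laser_softclamp_value)
    out = attn @ v
    return out.log()
```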