
Commit 165d795

make attention softmax done in full precision customizable, bfloat16 should be ok

1 parent 35b6bc1 · commit 165d795
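The rationale behind relaxing the previously hard-coded float32 softmax: bfloat16 shares float32's 8-bit exponent, so softmax intermediates have the same dynamic range and are far less prone to overflow than in float16 with its 5-bit exponent. A quick check with standard PyTorch (illustrative only, not part of the commit):

    import torch

    # float16 has a 5-bit exponent and tops out quickly;
    # bfloat16 keeps float32's 8-bit exponent range
    print(torch.finfo(torch.float16).max)   # 65504.0
    print(torch.finfo(torch.bfloat16).max)  # ~3.39e38
    print(torch.finfo(torch.float32).max)   # ~3.40e38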

File tree

alphafold3_pytorch/attention.py
pyproject.toml

2 files changed: +17 -4 lines changed

alphafold3_pytorch/attention.py

Lines changed: 16 additions & 3 deletions
@@ -178,7 +178,8 @@ def __init__(
         num_memory_kv: int = 0,
         enable_attn_softclamp = False,
         attn_softclamp_value = 50.,
-        init_gate_bias = -2.
+        init_gate_bias = -2.,
+        softmax_full_precision = False
     ):
         super().__init__()
         """
@@ -201,6 +202,7 @@ def __init__(
             window_size = window_size,
             enable_attn_softclamp = enable_attn_softclamp,
             attn_softclamp_value = attn_softclamp_value,
+            softmax_full_precision = softmax_full_precision
         )

         self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
@@ -279,7 +281,8 @@ def __init__(
         window_size = None,
         scale: float | None = None,
         enable_attn_softclamp = False,
-        attn_softclamp_value = 50.
+        attn_softclamp_value = 50.,
+        softmax_full_precision = False
     ):
         super().__init__()
         """
@@ -309,6 +312,9 @@ def __init__(
         self.enable_attn_softclamp = enable_attn_softclamp
         self.attn_softclamp_value = attn_softclamp_value

+        # whether to use full precision for softmax
+        self.softmax_full_precision = softmax_full_precision
+
     @typecheck
     def local_attn(
         self,
@@ -505,9 +511,16 @@ def forward(
                 mask, sim, max_neg_value(sim)
             )

+        # attention cast float32 - in case there are instabilities with float16
+
+        softmax_kwargs = dict()
+
+        if self.softmax_full_precision:
+            softmax_kwargs.update(dtype = torch.float32)
+
         # attention

-        attn = sim.softmax(dim = -1, dtype = torch.float32)
+        attn = sim.softmax(dim = -1, **softmax_kwargs)
         attn = attn.to(dtype)

         attn = self.attn_dropout(attn)
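For context, `Tensor.softmax` accepts an optional `dtype` argument that casts the input to that dtype before the softmax is computed and returns the result in it; that is all the new `softmax_kwargs` plumbing toggles. A minimal sketch of both paths, using only standard PyTorch (illustrative, not part of the commit):

    import torch

    sim = torch.randn(1, 8, 16, 16, dtype = torch.bfloat16)  # attention logits

    # previous behavior, now opt-in via softmax_full_precision = True:
    # upcast to float32 for extra numerical headroom, then cast back
    attn_full = sim.softmax(dim = -1, dtype = torch.float32).to(sim.dtype)

    # new default: softmax stays in the ambient dtype
    attn_half = sim.softmax(dim = -1)

    assert attn_half.dtype == torch.bfloat16

To restore the old always-full-precision behavior from the wrapper, the flag can be passed at construction. A sketch, assuming the wrapper class is exported as `Attention` from `alphafold3_pytorch.attention`; apart from `softmax_full_precision` and `heads` (both visible in the diff above), the argument names and values are hypothetical:

    from alphafold3_pytorch.attention import Attention

    attn = Attention(
        dim = 64,    # hypothetical
        heads = 8,
        softmax_full_precision = True
    )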

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "alphafold3-pytorch"
-version = "0.5.49"
+version = "0.5.50"
 description = "Alphafold 3 - Pytorch"
 authors = [
     { name = "Phil Wang", email = "[email protected]" },
