Commit 5849d25

attn softmax done in float32
1 parent b447339 commit 5849d25

File tree

2 files changed: +6 -2 lines changed

alphafold3_pytorch/attention.py

Lines changed: 5 additions & 1 deletion

@@ -447,6 +447,8 @@ def forward(
         memory_kv: Float['2 h m d'] | None = None
     ) -> Float['b h i d']:
 
+        dtype = q.dtype
+
         is_windowed_attn_bias = None
 
         if exists(attn_bias):
@@ -506,7 +508,9 @@ def forward(
 
         # attention
 
-        attn = sim.softmax(dim = -1)
+        attn = sim.softmax(dim = -1, dtype = torch.float32)
+        attn = attn.to(dtype)
+
         attn = self.attn_dropout(attn)
 
         # aggregate values
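
For context, the change upcasts only the softmax to float32 and immediately casts the result back to the working dtype, so the surrounding matmuls stay in half precision while the normalization itself avoids the reduced range of float16/bfloat16. A minimal sketch of the same pattern, assuming a plain scaled-dot-product attention; the `attend` helper and tensor shapes below are illustrative, not the repository's actual Attention module:

# Sketch of the commit's pattern: logits in the working dtype,
# softmax accumulated in float32, result cast back before value aggregation.
import torch

def attend(q, k, v):
    dtype = q.dtype                      # remember working dtype (e.g. bfloat16)
    scale = q.shape[-1] ** -0.5

    # similarity logits stay in the working dtype
    sim = torch.einsum('b h i d, b h j d -> b h i j', q, k) * scale

    # softmax computed in float32 for numerical stability, then cast back
    attn = sim.softmax(dim = -1, dtype = torch.float32)
    attn = attn.to(dtype)

    return torch.einsum('b h i j, b h j d -> b h i d', attn, v)

q = k = v = torch.randn(1, 8, 16, 64, dtype = torch.bfloat16)
out = attend(q, k, v)
assert out.dtype == torch.bfloat16

Casting back right after the softmax keeps memory use and downstream kernel dtypes unchanged; only the softmax reduction itself pays the float32 cost.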

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 [project]
 name = "alphafold3-pytorch"
-version = "0.2.66"
+version = "0.2.67"
 description = "Alphafold 3 - Pytorch"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }
