1 parent b447339 commit 5849d25
alphafold3_pytorch/attention.py
@@ -447,6 +447,8 @@ def forward(
     memory_kv: Float['2 h m d'] | None = None
 ) -> Float['b h i d']:
 
+    dtype = q.dtype
+
     is_windowed_attn_bias = None
 
     if exists(attn_bias):
@@ -506,7 +508,9 @@ def forward(
 
     # attention
 
-    attn = sim.softmax(dim = -1)
+    attn = sim.softmax(dim = -1, dtype = torch.float32)
+    attn = attn.to(dtype)
 
     attn = self.attn_dropout(attn)
 
     # aggregate values
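The change above records the incoming query dtype, computes the attention softmax in float32, and casts the result back. This guards against overflow and precision loss when the logits arrive in float16 or bfloat16 (e.g. under autocast), while keeping the downstream matmuls in half precision. A minimal, self-contained sketch of that pattern, assuming a hypothetical `stable_softmax` helper and illustrative tensor shapes (neither is from the repository):

```python
import torch

def stable_softmax(sim: torch.Tensor) -> torch.Tensor:
    # remember the incoming dtype (e.g. float16 / bfloat16)
    dtype = sim.dtype
    # accumulate the softmax in float32 for numerical stability
    attn = sim.softmax(dim = -1, dtype = torch.float32)
    # cast back so subsequent ops stay in the original precision
    return attn.to(dtype)

if __name__ == '__main__':
    # (batch, heads, query len, key len) attention logits in half precision
    sim = torch.randn(2, 8, 16, 16, dtype = torch.float16)
    attn = stable_softmax(sim)
    assert attn.dtype == torch.float16
    # rows still sum to ~1 after the round trip
    assert torch.allclose(
        attn.sum(dim = -1).float(),
        torch.ones(2, 8, 16),
        atol = 1e-3
    )
```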
pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "alphafold3-pytorch"
-version = "0.2.66"
+version = "0.2.67"
 description = "Alphafold 3 - Pytorch"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }