Skip to content

Commit 6388cfc

Browse files
committed
turn off mixed precision in weighted rigid align due to svd
1 parent b3c2b73 commit 6388cfc

File tree

2 files changed

+3
-1
lines changed

2 files changed

+3
-1
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import torch
1111
from torch import nn
1212
from torch import Tensor
13+
from torch.amp import autocast
1314
import torch.nn.functional as F
1415
from torch.utils.checkpoint import checkpoint, checkpoint_sequential
1516

@@ -2979,6 +2980,7 @@ class WeightedRigidAlign(Module):
29792980
""" Algorithm 28 """
29802981

29812982
@typecheck
2983+
@autocast('cuda', enabled = False)
29822984
def forward(
29832985
self,
29842986
pred_coords: Float['b n 3'], # predicted coordinates

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.4.5"
3+
version = "0.4.6"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" },

0 commit comments

Comments
 (0)