Skip to content

Commit 050ef0f

Browse files
authored
[fix] collate_af3_inputs function should return an instence of BatchedAtomInput (#48)
* [fix] collate_af3_inputs function should return an instence of BatchedAtomInput * use logger to log info
1 parent 11094c4 commit 050ef0f

File tree

3 files changed

+6
-4
lines changed

3 files changed

+6
-4
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from torch import nn, sigmoid
1010
from torch import Tensor
1111
import torch.nn.functional as F
12+
from loguru import logger
1213

1314
from torch.nn import (
1415
Module,
@@ -2380,7 +2381,7 @@ def forward(
23802381
true_coords_centered = true_coords - true_centroid
23812382

23822383
if num_points < (dim + 1):
2383-
print(
2384+
logger.warning(
23842385
"Warning: The size of one of the point clouds is <= dim+1. "
23852386
+ "`WeightedRigidAlign` cannot return a unique rotation."
23862387
)
@@ -2393,7 +2394,7 @@ def forward(
23932394

23942395
# Catch ambiguous rotation by checking the magnitude of singular values
23952396
if (S.abs() <= 1e-15).any() and not (num_points < (dim + 1)):
2396-
print(
2397+
logger.warning(
23972398
"Warning: Excessively low rank of "
23982399
+ "cross-correlation between aligned point clouds. "
23992400
+ "`WeightedRigidAlign` cannot return a unique rotation."
@@ -3182,7 +3183,7 @@ def load(
31823183
current_version = version('alphafold3_pytorch')
31833184

31843185
if model_package['version'] != current_version:
3185-
print(f'loading a saved model from version {model_package["version"]} but you are on version {current_version}')
3186+
logger.warning(f'loading a saved model from version {model_package["version"]} but you are on version {current_version}')
31863187

31873188
self.load_state_dict(model_package['state_dict'], strict = strict)
31883189

alphafold3_pytorch/trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def collate_af3_inputs(
183183

184184
# reconstitute dictionary
185185

186-
batched_atom_inputs = AtomInput(tuple(zip(keys, outputs)))
186+
batched_atom_inputs = BatchedAtomInput(tuple(zip(keys, outputs)))
187187
return batched_atom_inputs
188188

189189
@typecheck

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ dependencies = [
4545
'torch_geometric',
4646
"torch>=2.1",
4747
"tqdm>=4.66.4",
48+
"loguru",
4849
]
4950

5051
[project.urls]

0 commit comments

Comments
 (0)