File tree Expand file tree Collapse file tree 3 files changed +6
-4
lines changed Expand file tree Collapse file tree 3 files changed +6
-4
lines changed Original file line number Diff line number Diff line change 99from torch import nn , sigmoid
1010from torch import Tensor
1111import torch .nn .functional as F
12+ from loguru import logger
1213
1314from torch .nn import (
1415 Module ,
@@ -2380,7 +2381,7 @@ def forward(
23802381 true_coords_centered = true_coords - true_centroid
23812382
23822383 if num_points < (dim + 1 ):
2383- print (
2384+ logger . warning (
23842385 "Warning: The size of one of the point clouds is <= dim+1. "
23852386 + "`WeightedRigidAlign` cannot return a unique rotation."
23862387 )
@@ -2393,7 +2394,7 @@ def forward(
23932394
23942395 # Catch ambiguous rotation by checking the magnitude of singular values
23952396 if (S .abs () <= 1e-15 ).any () and not (num_points < (dim + 1 )):
2396- print (
2397+ logger . warning (
23972398 "Warning: Excessively low rank of "
23982399 + "cross-correlation between aligned point clouds. "
23992400 + "`WeightedRigidAlign` cannot return a unique rotation."
@@ -3182,7 +3183,7 @@ def load(
31823183 current_version = version ('alphafold3_pytorch' )
31833184
31843185 if model_package ['version' ] != current_version :
3185- print (f'loading a saved model from version { model_package ["version" ]} but you are on version { current_version } ' )
3186+ logger . warning (f'loading a saved model from version { model_package ["version" ]} but you are on version { current_version } ' )
31863187
31873188 self .load_state_dict (model_package ['state_dict' ], strict = strict )
31883189
Original file line number Diff line number Diff line change @@ -183,7 +183,7 @@ def collate_af3_inputs(
183183
184184 # reconstitute dictionary
185185
186- batched_atom_inputs = AtomInput (tuple (zip (keys , outputs )))
186+ batched_atom_inputs = BatchedAtomInput (tuple (zip (keys , outputs )))
187187 return batched_atom_inputs
188188
189189@typecheck
Original file line number Diff line number Diff line change @@ -45,6 +45,7 @@ dependencies = [
4545 ' torch_geometric' ,
4646 " torch>=2.1" ,
4747 " tqdm>=4.66.4" ,
48+ " loguru" ,
4849]
4950
5051[project .urls ]
You can’t perform that action at this time.
0 commit comments