Skip to content

Commit e85f816

Browse files
authored
Make check batch-aware (#200)
* Update inputs.py * Update tensor_typing.py
1 parent 2becc5e commit e85f816

File tree

2 files changed

+7
-1
lines changed

2 files changed

+7
-1
lines changed

alphafold3_pytorch/inputs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ def hard_validate_atom_indices_ascending(
199199

200200
# NOTE: this is a relaxed assumption, i.e., that if all -1 or only one molecule, then it passes the test
201201

202-
if present_indices.numel() <= 1:
202+
if present_indices.shape[-1] <= 1:
203203
continue
204204

205205
difference = einx.subtract(

alphafold3_pytorch/tensor_typing.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
Shaped,
1717
jaxtyped
1818
)
19+
from loguru import logger
1920

2021
from torch import Tensor
2122

@@ -65,6 +66,11 @@ def __getitem__(self, shapes: str):
6566

6667
beartype_isinstance = is_bearable if should_typecheck else always(True)
6768

69+
if should_typecheck:
70+
logger.info("Type checking is enabled.")
71+
if IS_DEBUGGING:
72+
logger.info("Debugging is enabled.")
73+
6874
__all__ = [
6975
Shaped,
7076
Float,

0 commit comments

Comments
 (0)