Skip to content

Commit cffd6a0

Browse files
committed
jaxtyping unable to help here, so just do asserts on shape manually
1 parent ff1daf1 commit cffd6a0

File tree

2 files changed

+7
-1
lines changed

2 files changed

+7
-1
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,9 @@ def log(t, eps = 1e-20):
138138
def divisible_by(num, den):
139139
return (num % den) == 0
140140

141+
def compact(*args):
142+
return tuple(filter(exists, args))
143+
141144
# tensor helpers
142145

143146
def max_neg_value(t: Tensor):
@@ -3905,6 +3908,9 @@ def forward(
39053908

39063909
# cross entropy losses
39073910

3911+
assert all([t.shape[-1] for t in compact(pde_labels, plddt_labels, resolved_labels)])
3912+
assert pde_labels.shape[-1] == ch_logits.pde.shape[-1]
3913+
39083914
if exists(pae_labels):
39093915
pae_labels = torch.where(label_pairwise_mask, pae_labels, ignore)
39103916
pae_loss = F.cross_entropy(ch_logits.pae, pae_labels, ignore_index = ignore)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.2.12"
3+
version = "0.2.14"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

0 commit comments

Comments
 (0)