Skip to content

Commit fca359d

Browse files
committed
account if some sample within a batch does not have any msas or templates
1 parent cbd32f3 commit fca359d

File tree

2 files changed

+22
-2
lines changed

2 files changed

+22
-2
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -780,6 +780,11 @@ def forward(
780780
if exists(msa_mask):
781781
msa_mask = einx.get_at('b [s], b sampled -> b sampled', msa_mask, indices)
782782

783+
# account for no msa
784+
785+
if exists(msa_mask):
786+
has_msa = reduce(msa_mask, 'b s -> b', 'any')
787+
783788
# process msa
784789

785790
msa = self.msa_init_proj(msa)
@@ -806,6 +811,12 @@ def forward(
806811

807812
pairwise_repr = pairwise_block(pairwise_repr = pairwise_repr, mask = mask)
808813

814+
if exists(msa_mask):
815+
pairwise_repr = einx.where(
816+
'b, b ..., -> b ...',
817+
has_msa, pairwise_repr, 0.
818+
)
819+
809820
return pairwise_repr
810821

811822
# pairformer stack
@@ -1033,6 +1044,8 @@ def forward(
10331044

10341045
v, merged_batch_ps = pack_one(v, '* i j d')
10351046

1047+
has_templates = reduce(template_mask, 'b t -> b', 'any')
1048+
10361049
if exists(mask):
10371050
mask = repeat(mask, 'b n -> (b t) n', t = num_templates)
10381051

@@ -1058,7 +1071,14 @@ def forward(
10581071

10591072
avg_template_repr = einx.divide('b i j d, b -> b i j d', num, den.clamp(min = self.eps))
10601073

1061-
return self.to_out(avg_template_repr)
1074+
out = self.to_out(avg_template_repr)
1075+
1076+
out = einx.where(
1077+
'b, b ..., -> b ...',
1078+
has_templates, out, 0.
1079+
)
1080+
1081+
return out
10621082

10631083
# diffusion related
10641084
# both diffusion transformer as well as atom encoder / decoder

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.0.20"
3+
version = "0.0.21"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

0 commit comments

Comments
 (0)