
Commit 59360db

Ensure template embedder and resolved logits always contribute to the loss (#259)
* Update alphafold3.py
* Update inputs.py
* Update alphafold3.py
* Update inputs.py
* Update inputs.py
* Update alphafold3.py
1 parent dd76d25 commit 59360db

2 files changed: +62 -50 lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 21 additions & 7 deletions
@@ -5957,6 +5957,7 @@ def __init__(
         # store atom and atompair input dimensions for shape validation

         self.dim_atom_inputs = dim_atom_inputs
+        self.dim_template_feats = dim_template_feats
         self.dim_atompair_inputs = dim_atompair_inputs

         # optional atom and atom bond embeddings
@@ -6631,15 +6632,24 @@ def forward(

         # templates

-        if exists(templates):
-            embedded_template = self.template_embedder(
-                templates = templates,
-                template_mask = template_mask,
-                pairwise_repr = pairwise,
-                mask = mask
+        if not exists(templates):
+            templates = torch.zeros(
+                (batch_size, 1, seq_len, seq_len, self.dim_template_feats),
+                dtype=dtype,
+                device=self.device,
             )
+            template_mask = torch.zeros((batch_size, 1), dtype=torch.bool, device=self.device)

-            pairwise = embedded_template + pairwise
+        # ensure template embedder always contributes to the loss
+
+        embedded_template = self.template_embedder(
+            templates=templates,
+            template_mask=template_mask,
+            pairwise_repr=pairwise,
+            mask=mask,
+        )
+
+        pairwise = embedded_template + pairwise

         # msa

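The change above makes the template embedder run on every forward pass: when no templates are given, zero-filled templates and an all-False template mask are substituted, so the embedder's parameters always receive a (possibly zero) gradient. One common motivation for this pattern is distributed data parallel training, which by default expects every parameter to take part in each backward pass. The following is a minimal, self-contained sketch of the same idea; ToyTemplateEmbedder and all of the dimensions are made up for illustration and do not mirror the real embedder's internals.

import torch
from torch import nn

class ToyTemplateEmbedder(nn.Module):
    """Hypothetical stand-in for the real template embedder."""
    def __init__(self, dim_template_feats, dim_pairwise):
        super().__init__()
        self.proj = nn.Linear(dim_template_feats, dim_pairwise)

    def forward(self, templates, template_mask, pairwise_repr):
        embedded = self.proj(templates).sum(dim=1)  # (b, n, n, dim_pairwise)
        # zero out batch entries that carry no real template
        keep = template_mask.any(dim=-1).float().view(-1, 1, 1, 1)
        return embedded * keep

batch_size, seq_len, dim_template_feats, dim_pairwise = 2, 8, 44, 16
embedder = ToyTemplateEmbedder(dim_template_feats, dim_pairwise)
pairwise = torch.randn(batch_size, seq_len, seq_len, dim_pairwise)

templates, template_mask = None, None  # caller supplied no templates
if templates is None:
    # same fallback as the commit: zero templates plus an all-False mask
    templates = torch.zeros(batch_size, 1, seq_len, seq_len, dim_template_feats)
    template_mask = torch.zeros(batch_size, 1, dtype=torch.bool)

pairwise = embedder(templates, template_mask, pairwise) + pairwise
pairwise.sum().backward()

print(embedder.proj.weight.grad is not None)  # True: the embedder stays in the graph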
@@ -7271,6 +7281,10 @@ def cross_entropy_with_weight(
                 f"ch_logits.resolved shape {ch_logits.resolved.shape[-1]}"
             )
             resolved_loss = cross_entropy_with_weight(ch_logits.resolved, resolved_labels, confidence_weight, label_mask, ignore)
+        else:
+            resolved_loss = (
+                ch_logits.resolved * 0.0
+            ).mean()  # ensure resolved logits always contribute to the loss

         confidence_loss = pae_loss + pde_loss + plddt_loss + resolved_loss

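Likewise, when no resolved labels are available, the new else branch builds a loss term from ch_logits.resolved * 0.0. The term is numerically zero but still differentiable, so the resolved head stays in the autograd graph. A minimal sketch, with a hypothetical resolved_head standing in for the real confidence head:

import torch
from torch import nn

resolved_head = nn.Linear(16, 2)           # hypothetical resolved-logits head
features = torch.randn(4, 16)
resolved_logits = resolved_head(features)

resolved_labels = None                     # labels unavailable for this batch
if resolved_labels is not None:
    resolved_loss = nn.functional.cross_entropy(resolved_logits, resolved_labels)
else:
    # zero-valued but differentiable: keeps resolved_head in the autograd graph
    resolved_loss = (resolved_logits * 0.0).mean()

resolved_loss.backward()
print(resolved_head.weight.grad)           # zero tensor, but populated

The printed gradient is a zero tensor rather than None, which is the point: the head's parameters participate in backward even when they cannot affect the loss value.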
alphafold3_pytorch/inputs.py

Lines changed: 41 additions & 43 deletions
@@ -2485,7 +2485,7 @@ def load_msa_from_msa_dir(
     """Load MSA from a directory containing MSA files."""
     if verbose and (not_exists(msa_dir) or not os.path.exists(msa_dir)):
         logger.warning(
-            f"{msa_dir} does not exist. Dummy MSA features for each chain of file {file_id} will instead be loaded."
+            f"{msa_dir} MSA directory does not exist. Dummy MSA features for each chain of file {file_id} will instead be loaded."
         )

     msas = {}
@@ -2600,35 +2600,32 @@ def load_templates_from_templates_dir(
     kalign_binary_path: str | None = None,
     template_cutoff_date: datetime | None = None,
     randomly_sample_num_templates: bool = False,
-    raise_missing_exception: bool = False,
     verbose: bool = False,
 ) -> FeatureDict:
     """Load templates from a directory containing template PDB mmCIF files."""
-    if (
-        not_exists(templates_dir) or not os.path.exists(templates_dir)
-    ) and raise_missing_exception:
-        raise FileNotFoundError(f"{templates_dir} does not exist.")
-    elif not_exists(templates_dir) or not os.path.exists(templates_dir):
-        if verbose:
-            logger.warning(
-                f"{templates_dir} does not exist. Skipping template loading by returning `Nones`."
-            )
-        return {}
+    if verbose and (not_exists(templates_dir) or not os.path.exists(templates_dir)):
+        logger.warning(
+            f"{templates_dir} templates directory does not exist. Dummy template features for each chain of file {file_id} will instead be loaded."
+        )

-    if (not_exists(mmcif_dir) or not os.path.exists(mmcif_dir)) and raise_missing_exception:
-        raise FileNotFoundError(f"{mmcif_dir} does not exist.")
-    elif not_exists(mmcif_dir) or not os.path.exists(mmcif_dir):
-        if verbose:
-            logger.warning(
-                f"{mmcif_dir} does not exist. Skipping template loading by returning `Nones`."
-            )
-        return {}
+    if verbose and (not_exists(mmcif_dir) or not os.path.exists(mmcif_dir)):
+        logger.warning(
+            f"{mmcif_dir} mmCIF templates directory does not exist. Dummy template features for each chain of file {file_id} will instead be loaded."
+        )

     templates = defaultdict(list)
     for chain_id in chain_id_to_residue:
-        template_fpaths = glob.glob(os.path.join(templates_dir, f"{file_id}{chain_id}_*.m8"))
+        template_fpaths = (
+            glob.glob(os.path.join(templates_dir, f"{file_id}{chain_id}_*.m8"))
+            if exists(templates_dir)
+            else []
+        )

         if not template_fpaths:
+            if verbose:
+                logger.warning(
+                    f"Could not find template for chain {chain_id} of file {file_id}. A dummy template will be installed for this chain."
+                )
             templates[chain_id] = []
             continue

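Rather than raising FileNotFoundError or returning {} early (the removed raise_missing_exception path), the loader now warns and degrades to dummy template features per chain, with the glob guarded so a missing templates_dir simply yields no matches. A rough standalone sketch of that guarded lookup, using illustrative names (find_template_files is not part of the package's API):

import glob
import os

def find_template_files(templates_dir, file_id, chain_id, verbose=False):
    # guard the glob so a missing directory yields an empty list instead of an error
    fpaths = (
        glob.glob(os.path.join(templates_dir, f"{file_id}{chain_id}_*.m8"))
        if templates_dir is not None and os.path.exists(templates_dir)
        else []
    )
    if not fpaths and verbose:
        print(f"Could not find template for chain {chain_id} of file {file_id}; "
              "a dummy template will be used for this chain.")
    return fpaths

print(find_template_files(None, "1abc", "A", verbose=True))  # warns, returns []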
@@ -2808,30 +2805,31 @@ def pdb_input_to_molecule_input(
         exists(feat)
         for feat in [msa, msa_row_mask, has_deletion, deletion_value, profile, deletion_mean]
     )
-    if all_msa_features_exist:
-        assert (
-            msa.shape[-1] == num_tokens
-        ), f"The number of tokens in the MSA ({msa.shape[-1]}) does not match the number of tokens in the biomolecule ({num_tokens}). "
-
-        additional_msa_feats = torch.stack(
-            [
-                has_deletion,
-                deletion_value,
-            ],
-            dim=-1,
-        )

-        additional_token_feats = torch.cat(
-            [
-                profile,
-                deletion_mean[:, None],
-            ],
-            dim=-1,
-        )
+    assert all_msa_features_exist, "All MSA features must be derived for each example."
+    assert (
+        msa.shape[-1] == num_tokens
+    ), f"The number of tokens in the MSA ({msa.shape[-1]}) does not match the number of tokens in the biomolecule ({num_tokens}). "
+
+    additional_msa_feats = torch.stack(
+        [
+            has_deletion,
+            deletion_value,
+        ],
+        dim=-1,
+    )
+
+    additional_token_feats = torch.cat(
+        [
+            profile,
+            deletion_mean[:, None],
+        ],
+        dim=-1,
+    )

-        # convert the MSA into a one-hot representation
-        msa = make_one_hot(msa, NUM_MSA_ONE_HOT)
-        msa_row_mask = msa_row_mask.bool()
+    # convert the MSA into a one-hot representation
+    msa = make_one_hot(msa, NUM_MSA_ONE_HOT)
+    msa_row_mask = msa_row_mask.bool()

     # retrieve templates for each chain

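With the if all_msa_features_exist: guard replaced by assertions, the MSA-derived features are now built unconditionally. The tensor plumbing is small enough to check with dummy inputs; the sizes below are arbitrary and F.one_hot stands in for the package's make_one_hot:

import torch
import torch.nn.functional as F

num_msa, num_tokens, num_msa_one_hot = 3, 10, 32   # arbitrary sizes for illustration

msa = torch.randint(0, num_msa_one_hot, (num_msa, num_tokens))
has_deletion = torch.zeros(num_msa, num_tokens)
deletion_value = torch.zeros(num_msa, num_tokens)
profile = torch.rand(num_tokens, num_msa_one_hot)
deletion_mean = torch.zeros(num_tokens)

additional_msa_feats = torch.stack([has_deletion, deletion_value], dim=-1)
additional_token_feats = torch.cat([profile, deletion_mean[:, None]], dim=-1)
msa_one_hot = F.one_hot(msa, num_msa_one_hot).float()  # stand-in for make_one_hot

print(additional_msa_feats.shape)    # torch.Size([3, 10, 2])
print(additional_token_feats.shape)  # torch.Size([10, 33])
print(msa_one_hot.shape)             # torch.Size([3, 10, 32])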