Skip to content

Commit c15ce86

Browse files
authored
Fix various edge case bugs raised during training-time cropping (#149)
* Update inputs.py * Update biomolecule.py
1 parent 975e11d commit c15ce86

File tree

2 files changed

+29
-11
lines changed

2 files changed

+29
-11
lines changed

alphafold3_pytorch/common/biomolecule.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -399,6 +399,8 @@ def spatial_crop(
399399
token_center_atom_mask[self.chain_id == chain_1] = True
400400
elif exists(chain_2):
401401
token_center_atom_mask[self.chain_id == chain_2] = True
402+
else:
403+
raise ValueError("At least one chain ID must be specified for spatial cropping.")
402404

403405
# potentially filter candidate token center atoms by interface proximity
404406

@@ -452,7 +454,16 @@ def crop(
452454
) -> "Biomolecule":
453455
"""Crop a Biomolecule using a randomly-sampled cropping function."""
454456
n_res = min(n_res, len(self.atom_mask))
455-
crop_fn_weights = [contiguous_weight, spatial_weight, spatial_interface_weight]
457+
if exists(chain_1) and exists(chain_2):
458+
crop_fn_weights = [contiguous_weight, spatial_weight, spatial_interface_weight]
459+
elif exists(chain_1) or exists(chain_2):
460+
crop_fn_weights = [contiguous_weight, spatial_weight + spatial_interface_weight, 0.0]
461+
else:
462+
crop_fn_weights = [
463+
contiguous_weight + spatial_weight + spatial_interface_weight,
464+
0.0,
465+
0.0,
466+
]
456467
crop_fns = [
457468
partial(self.contiguous_crop, n_res=n_res),
458469
partial(

alphafold3_pytorch/inputs.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2113,15 +2113,16 @@ def pdb_input_to_molecule_input(
21132113
"""Convert a PDBInput to a MoleculeInput."""
21142114
i = pdb_input
21152115

2116+
filepath = pdb_input.mmcif_filepath
2117+
file_id = os.path.splitext(os.path.basename(filepath))[0]
2118+
21162119
# acquire a `Biomolecule` object for the given `PDBInput`
21172120

21182121
if not exists(biomol) and exists(i.biomol):
21192122
biomol = i.biomol
21202123
else:
21212124
# construct a `Biomolecule` object from the input PDB mmCIF file
21222125

2123-
filepath = pdb_input.mmcif_filepath
2124-
file_id = os.path.splitext(os.path.basename(filepath))[0]
21252126
assert os.path.exists(filepath), f"PDB input file `{filepath}` does not exist."
21262127

21272128
mmcif_object = mmcif_parsing.parse_mmcif_object(
@@ -2162,14 +2163,20 @@ def pdb_input_to_molecule_input(
21622163
assert exists(
21632164
i.cropping_config
21642165
), "A cropping configuration must be provided during training."
2165-
biomol = biomol.crop(
2166-
contiguous_weight=i.cropping_config["contiguous_weight"],
2167-
spatial_weight=i.cropping_config["spatial_weight"],
2168-
spatial_interface_weight=i.cropping_config["spatial_interface_weight"],
2169-
n_res=i.cropping_config["n_res"],
2170-
chain_1=i.chains[0],
2171-
chain_2=i.chains[1],
2172-
)
2166+
try:
2167+
assert exists(i.chains), "Chain IDs must be provided for cropping during training."
2168+
chain_id_1, chain_id_2 = i.chains
2169+
2170+
biomol = biomol.crop(
2171+
contiguous_weight=i.cropping_config["contiguous_weight"],
2172+
spatial_weight=i.cropping_config["spatial_weight"],
2173+
spatial_interface_weight=i.cropping_config["spatial_interface_weight"],
2174+
n_res=i.cropping_config["n_res"],
2175+
chain_1=chain_id_1 if chain_id_1 else None,
2176+
chain_2=chain_id_2 if chain_id_2 else None,
2177+
)
2178+
except Exception as e:
2179+
raise ValueError(f"Failed to crop the biomolecule for input {file_id} due to: {e}")
21732180

21742181
# retrieve features directly available within the `Biomolecule` object
21752182

0 commit comments

Comments
 (0)