Skip to content

Commit c371712

Browse files
committed
for now just properly default some mask fields that are not passed into AtomInput
1 parent 29c2ab9 commit c371712

File tree

2 files changed

+32
-5
lines changed

2 files changed

+32
-5
lines changed

alphafold3_pytorch/inputs.py

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,26 @@ def file_to_atom_input(path: str | Path) -> AtomInput:
245245
atom_input_dict = torch.load(str(path))
246246
return AtomInput(**atom_input_dict)
247247

248+
@typecheck
249+
def default_none_fields_atom_input(i: AtomInput) -> AtomInput:
250+
251+
# if templates given but template mask isn't given, default to all True
252+
253+
if exists(i.templates) and not exists(i.template_mask):
254+
i.template_mask = torch.ones(i.templates.shape[0], dtype = torch.bool)
255+
256+
# if msa given but msa mask isn't given default to all True
257+
258+
if exists(i.msa) and not exists(i.msa_mask):
259+
i.msa_mask = torch.ones(i.msa.shape[0], dtype = torch.bool)
260+
261+
# default missing atom mask should be all False
262+
263+
if not exists(i.missing_atom_mask):
264+
i.missing_atom_mask = torch.zeros(i.atom_inputs.shape[0], dtype = torch.bool)
265+
266+
return i
267+
248268
@typecheck
249269
def pdb_dataset_to_atom_inputs(
250270
pdb_dataset: PDBDataset,
@@ -2698,15 +2718,22 @@ def __getitem__(self, idx: int | str) -> PDBInput:
26982718
# this can be preprocessed or will be taken care of automatically within the Trainer during data collation
26992719

27002720
INPUT_TO_ATOM_TRANSFORM = {
2701-
AtomInput: identity,
2702-
MoleculeInput: molecule_to_atom_input,
2721+
AtomInput: compose(
2722+
default_none_fields_atom_input
2723+
),
2724+
MoleculeInput: compose(
2725+
molecule_to_atom_input,
2726+
default_none_fields_atom_input
2727+
),
27032728
Alphafold3Input: compose(
27042729
alphafold3_input_to_molecule_lengthed_molecule_input,
2705-
molecule_lengthed_molecule_input_to_atom_input
2730+
molecule_lengthed_molecule_input_to_atom_input,
2731+
default_none_fields_atom_input
27062732
),
27072733
PDBInput: compose(
27082734
pdb_input_to_molecule_input,
2709-
molecule_to_atom_input
2735+
molecule_to_atom_input,
2736+
default_none_fields_atom_input
27102737
),
27112738
}
27122739

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.2.116"
3+
version = "0.2.117"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" },

0 commit comments

Comments
 (0)