@@ -245,6 +245,26 @@ def file_to_atom_input(path: str | Path) -> AtomInput:
245245 atom_input_dict = torch .load (str (path ))
246246 return AtomInput (** atom_input_dict )
247247
248+ @typecheck
249+ def default_none_fields_atom_input (i : AtomInput ) -> AtomInput :
250+
251+ # if templates given but template mask isn't given, default to all True
252+
253+ if exists (i .templates ) and not exists (i .template_mask ):
254+ i .template_mask = torch .ones (i .templates .shape [0 ], dtype = torch .bool )
255+
256+ # if msa given but msa mask isn't given default to all True
257+
258+ if exists (i .msa ) and not exists (i .msa_mask ):
259+ i .msa_mask = torch .ones (i .msa .shape [0 ], dtype = torch .bool )
260+
261+ # default missing atom mask should be all False
262+
263+ if not exists (i .missing_atom_mask ):
264+ i .missing_atom_mask = torch .zeros (i .atom_inputs .shape [0 ], dtype = torch .bool )
265+
266+ return i
267+
248268@typecheck
249269def pdb_dataset_to_atom_inputs (
250270 pdb_dataset : PDBDataset ,
@@ -2698,15 +2718,22 @@ def __getitem__(self, idx: int | str) -> PDBInput:
26982718# this can be preprocessed or will be taken care of automatically within the Trainer during data collation
26992719
27002720INPUT_TO_ATOM_TRANSFORM = {
2701- AtomInput : identity ,
2702- MoleculeInput : molecule_to_atom_input ,
2721+ AtomInput : compose (
2722+ default_none_fields_atom_input
2723+ ),
2724+ MoleculeInput : compose (
2725+ molecule_to_atom_input ,
2726+ default_none_fields_atom_input
2727+ ),
27032728 Alphafold3Input : compose (
27042729 alphafold3_input_to_molecule_lengthed_molecule_input ,
2705- molecule_lengthed_molecule_input_to_atom_input
2730+ molecule_lengthed_molecule_input_to_atom_input ,
2731+ default_none_fields_atom_input
27062732 ),
27072733 PDBInput : compose (
27082734 pdb_input_to_molecule_input ,
2709- molecule_to_atom_input
2735+ molecule_to_atom_input ,
2736+ default_none_fields_atom_input
27102737 ),
27112738}
27122739
0 commit comments