Skip to content

Commit 0d9ced7

Browse files
authored
Add code necessary for overfitting experiments (#151)
* Update mmcif_writing.py
* Update trainer.py
* Update alphafold3.py
* Update inputs.py
* Update test_input.py
* Create 209d-assembly1.cif
* Create 721p-assembly1.cif
1 parent c371712 commit 0d9ced7

File tree

7 files changed

+4177
-19
lines changed

7 files changed

+4177
-19
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4833,6 +4833,7 @@ def forward(
48334833
plddt_labels: Int['b n'] | Int['b m'] | None = None,
48344834
resolved_labels: Int['b n'] | Int['b m'] | None = None,
48354835
chains: Int['b 2'] | None = None,
4836+
filepath: List[str] | None = None,
48364837
return_loss_breakdown = False,
48374838
return_loss: bool = None,
48384839
return_present_sampled_atoms: bool = False,

alphafold3_pytorch/data/mmcif_writing.py

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,24 +12,15 @@
1212
from alphafold3_pytorch.data.mmcif_parsing import MmcifObject, parse_mmcif_object
1313
from alphafold3_pytorch.utils.utils import exists
1414

15+
1516
def write_mmcif_from_filepath_and_id(
16-
filepath: str,
17-
file_id: str,
18-
suffix: str = 'sampled',
19-
**kwargs
17+
input_filepath: str, output_filepath: str, file_id: str, **kwargs
2018
):
21-
mmcif_object = parse_mmcif_object(
22-
filepath = filepath,
23-
file_id = file_id
24-
)
19+
"""Write an input mmCIF file to an output mmCIF filepath using the provided keyword arguments
20+
(e.g., sampled coordinates)."""
21+
mmcif_object = parse_mmcif_object(filepath=input_filepath, file_id=file_id)
22+
return write_mmcif(mmcif_object, output_filepath=output_filepath, **kwargs)
2523

26-
output_filepath = filepath.replace(".cif", f"-{suffix}.cif")
27-
28-
return write_mmcif(
29-
mmcif_object,
30-
output_filepath = output_filepath,
31-
**kwargs
32-
)
3324

3425
def write_mmcif(
3526
mmcif_object: MmcifObject,

alphafold3_pytorch/inputs.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,7 @@ class AtomInput:
178178
plddt_labels: Int[' n'] | None = None
179179
resolved_labels: Int[' n'] | None = None
180180
chains: Int[" 2"] | None = None
181+
filepath: str | None = None
181182

182183
def dict(self):
183184
return asdict(self)
@@ -211,6 +212,7 @@ class BatchedAtomInput:
211212
plddt_labels: Int['b n'] | None = None
212213
resolved_labels: Int['b n'] | None = None
213214
chains: Int["b 2"] | None = None
215+
filepath: List[str] | None = None
214216

215217
def dict(self):
216218
return asdict(self)
@@ -432,6 +434,7 @@ class MoleculeInput:
432434
pde_labels: Int[' n'] | None = None
433435
resolved_labels: Int[' n'] | None = None
434436
chains: Tuple[int | None, int | None] | None = (None, None)
437+
filepath: str | None = None
435438
add_atom_ids: bool = False
436439
add_atompair_ids: bool = False
437440
directed_bonds: bool = False
@@ -712,6 +715,7 @@ def molecule_to_atom_input(mol_input: MoleculeInput) -> AtomInput:
712715
atom_ids=atom_ids,
713716
atompair_ids=atompair_ids,
714717
chains=chains,
718+
filepath=i.filepath,
715719
)
716720

717721
return atom_input
@@ -749,6 +753,7 @@ class MoleculeLengthMoleculeInput:
749753
pde_labels: Int[' n'] | None = None
750754
resolved_labels: Int[' n'] | None = None
751755
chains: Tuple[int | None, int | None] | None = (None, None)
756+
filepath: str | None = None
752757
add_atom_ids: bool = False
753758
add_atompair_ids: bool = False
754759
directed_bonds: bool = False
@@ -1135,6 +1140,7 @@ def molecule_lengthed_molecule_input_to_atom_input(mol_input: MoleculeLengthMole
11351140
atom_ids = atom_ids,
11361141
atompair_ids = atompair_ids,
11371142
chains = chains,
1143+
filepath=i.filepath,
11381144
)
11391145

11401146
return atom_input
@@ -2602,6 +2608,7 @@ def pdb_input_to_molecule_input(
26022608
template_mask=template_mask,
26032609
msa_mask=msa_mask,
26042610
chains=chains,
2611+
filepath=filepath,
26052612
add_atom_ids=i.add_atom_ids,
26062613
add_atompair_ids=i.add_atompair_ids,
26072614
directed_bonds=i.directed_bonds,

alphafold3_pytorch/trainer.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,12 +142,12 @@ def collate_inputs_to_batched_atom_input(
142142

143143
# separate input dictionary into keys and values
144144

145-
keys = atom_inputs[0].dict().keys()
145+
keys = list(atom_inputs[0].dict().keys())
146146
atom_inputs = [i.dict().values() for i in atom_inputs]
147147

148148
outputs = []
149149

150-
for grouped in zip(*atom_inputs):
150+
for group_index, grouped in enumerate(zip(*atom_inputs)):
151151
# if all None, just return None
152152

153153
not_none_grouped = [*filter(exists, grouped)]
@@ -156,6 +156,12 @@ def collate_inputs_to_batched_atom_input(
156156
outputs.append(None)
157157
continue
158158

159+
# collate list of input filepath strings
160+
161+
if keys[group_index] == "filepath":
162+
outputs.append(not_none_grouped)
163+
continue
164+
159165
# default to empty tensor for any Nones
160166

161167
one_tensor = not_none_grouped[0]

0 commit comments

Comments
 (0)