
Commit 9c68b6d

remove unpacked atom representation, to reduce complexity, and because eventual ligand fine-tuning makes no sense with that representation
1 parent 73ae179 commit 9c68b6d
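For context, a minimal sketch of what this commit changes (PyTorch, with shapes taken from the README diff below): the old unpacked representation reserved a fixed 27 atom slots per residue, while the new packed representation sizes the atom dimension from per-residue atom counts.

```python
import torch

batch, seq_len = 2, 16

# old, unpacked representation: a fixed 27 atom slots per residue,
# so small residues and ligands carry mostly padding
unpacked_atom_seq_len = seq_len * 27  # 432

# new, packed representation: each residue declares its actual atom count
residue_atom_lens = torch.randint(1, 3, (batch, seq_len))

# the atom dimension is sized to the longest packed sequence in the batch
atom_seq_len = residue_atom_lens.sum(dim = -1).amax()  # <= unpacked_atom_seq_len
```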

5 files changed: +94 −224 lines

README.md

Lines changed: 6 additions & 6 deletions
````diff
@@ -36,12 +36,12 @@ alphafold3 = Alphafold3(
 # mock inputs
 
 seq_len = 16
-atom_seq_len = seq_len * 27
+residue_atom_lens = torch.randint(1, 3, (2, seq_len))
+atom_seq_len = residue_atom_lens.sum(dim = -1).amax()
 
 atom_inputs = torch.randn(2, atom_seq_len, 77)
 atompair_inputs = torch.randn(2, atom_seq_len, atom_seq_len, 5)
 
-atom_lens = torch.randint(0, 27, (2, seq_len))
 additional_residue_feats = torch.randn(2, seq_len, 10)
 
 template_feats = torch.randn(2, 2, seq_len, seq_len, 44)
@@ -53,7 +53,7 @@ msa_mask = torch.ones((2, 7)).bool()
 # required for training, but omitted on inference
 
 atom_pos = torch.randn(2, atom_seq_len, 3)
-residue_atom_indices = torch.randint(0, 27, (2, seq_len))
+residue_atom_indices = residue_atom_lens - 1 # last atom, as an example
 
 distance_labels = torch.randint(0, 37, (2, seq_len, seq_len))
 pae_labels = torch.randint(0, 64, (2, seq_len, seq_len))
@@ -67,7 +67,7 @@ loss = alphafold3(
     num_recycling_steps = 2,
     atom_inputs = atom_inputs,
     atompair_inputs = atompair_inputs,
-    residue_atom_lens = atom_lens,
+    residue_atom_lens = residue_atom_lens,
     additional_residue_feats = additional_residue_feats,
     msa = msa,
     msa_mask = msa_mask,
@@ -91,15 +91,15 @@ sampled_atom_pos = alphafold3(
     num_sample_steps = 16,
     atom_inputs = atom_inputs,
     atompair_inputs = atompair_inputs,
-    residue_atom_lens = atom_lens,
+    residue_atom_lens = residue_atom_lens,
     additional_residue_feats = additional_residue_feats,
     msa = msa,
     msa_mask = msa_mask,
     templates = template_feats,
     template_mask = template_mask
 )
 
-sampled_atom_pos.shape # (2, 16 * 27, 3)
+sampled_atom_pos.shape # (2, <atom_seqlen>, 3)
 ```
 
 ## Contributing
````
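The diff does not show how padding is handled once atoms are packed; below is a hypothetical sketch of deriving an atom-level boolean mask from `residue_atom_lens`. The helper name `atom_mask_from_lens` is an assumption for illustration, not part of the repository's API.

```python
import torch

def atom_mask_from_lens(residue_atom_lens: torch.Tensor, atom_seq_len: int) -> torch.Tensor:
    # hypothetical helper: True at real (packed) atom positions, False at right-padding
    total_atoms = residue_atom_lens.sum(dim = -1)                              # (batch,)
    positions = torch.arange(atom_seq_len, device = residue_atom_lens.device)  # (atom_seq_len,)
    return positions[None, :] < total_atoms[:, None]                           # (batch, atom_seq_len)

residue_atom_lens = torch.randint(1, 3, (2, 16))
atom_seq_len = int(residue_atom_lens.sum(dim = -1).amax())
atom_mask = atom_mask_from_lens(residue_atom_lens, atom_seq_len)  # bool, (2, atom_seq_len)
```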
