Skip to content

Commit fdfe72b

Browse files
committed
completely wire up the diffusion module, and show end to end in readme
1 parent db0edca commit fdfe72b

File tree

3 files changed

+124
-7
lines changed

3 files changed

+124
-7
lines changed

README.md

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,11 @@ template_mask = torch.ones((2, 2)).bool()
4343

4444
msa = torch.randn(2, 7, seq_len, 64)
4545

46+
# required for training, but omitted on inference
47+
48+
atom_pos = torch.randn(2, atom_seq_len, 3)
49+
distance_labels = torch.randint(0, 37, (2, seq_len, seq_len))
50+
4651
# train
4752

4853
loss = alphafold3(
@@ -53,11 +58,28 @@ loss = alphafold3(
5358
additional_residue_feats = additional_residue_feats,
5459
msa = msa,
5560
templates = template_feats,
56-
template_mask = template_mask
61+
template_mask = template_mask,
62+
atom_pos = atom_pos,
63+
distance_labels = distance_labels
5764
)
5865

5966
loss.backward()
6067

68+
# after much training ...
69+
70+
sampled_atom_pos = alphafold3(
71+
num_recycling_steps = 4,
72+
num_sample_steps = 16,
73+
atom_inputs = atom_inputs,
74+
atom_mask = atom_mask,
75+
atompair_feats = atompair_feats,
76+
additional_residue_feats = additional_residue_feats,
77+
msa = msa,
78+
templates = template_feats,
79+
template_mask = template_mask
80+
)
81+
82+
sampled_atom_pos.shape # (2, 16 * 27, 3)
6183
```
6284

6385
## Citations
@@ -83,7 +105,7 @@ loss.backward()
83105
{\v Z}{\'\i}dek, Augustin and Bapst, Victor and Kohli, Pushmeet
84106
and Jaderberg, Max and Hassabis, Demis and Jumper, John M",
85107
journal = "Nature",
86-
month = may,
108+
month = "May",
87109
year = 2024
88110
}
89111
```

alphafold3_pytorch/alphafold3.py

Lines changed: 85 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1982,12 +1982,14 @@ def __init__(
19821982
dim_input_embedder_token = 384,
19831983
dim_single = 384,
19841984
dim_pairwise = 128,
1985+
dim_token = 768,
19851986
atompair_dist_bins: Float[' dist_bins'] = torch.linspace(3, 20, 37),
19861987
ignore_index = -1,
19871988
num_dist_bins = 38,
19881989
num_plddt_bins = 50,
19891990
num_pde_bins = 64,
19901991
num_pae_bins = 64,
1992+
sigma_data = 16,
19911993
loss_confidence_weight = 1e-4,
19921994
loss_distogram_weight = 1e-2,
19931995
loss_diffusion_weight = 4.,
@@ -2023,6 +2025,32 @@ def __init__(
20232025
relative_position_encoding_kwargs: dict = dict(
20242026
r_max = 32,
20252027
s_max = 2,
2028+
),
2029+
diffusion_module_kwargs: dict = dict(
2030+
single_cond_kwargs = dict(
2031+
num_transitions = 2,
2032+
transition_expansion_factor = 2,
2033+
),
2034+
pairwise_cond_kwargs = dict(
2035+
num_transitions = 2
2036+
),
2037+
atom_encoder_depth = 3,
2038+
atom_encoder_heads = 4,
2039+
token_transformer_depth = 24,
2040+
token_transformer_heads = 16,
2041+
atom_decoder_depth = 3,
2042+
atom_decoder_heads = 4
2043+
),
2044+
edm_kwargs: dict = dict(
2045+
sigma_min = 0.002,
2046+
sigma_max = 80,
2047+
rho = 7,
2048+
P_mean = -1.2,
2049+
P_std = 1.2,
2050+
S_churn = 80,
2051+
S_tmin = 0.05,
2052+
S_tmax = 50,
2053+
S_noise = 1.003,
20262054
)
20272055
):
20282056
super().__init__()
@@ -2091,6 +2119,27 @@ def __init__(
20912119
LinearNoBias(dim_pairwise, dim_pairwise)
20922120
)
20932121

2122+
# diffusion
2123+
2124+
self.diffusion_module = DiffusionModule(
2125+
dim_pairwise_trunk = dim_pairwise,
2126+
dim_pairwise_rel_pos_feats = dim_pairwise,
2127+
atoms_per_window = atoms_per_window,
2128+
dim_pairwise = dim_pairwise,
2129+
sigma_data = sigma_data,
2130+
dim_atom = dim_atom,
2131+
dim_atompair = dim_atompair,
2132+
dim_token = dim_token,
2133+
dim_single = dim_single + dim_single_inputs,
2134+
**diffusion_module_kwargs
2135+
)
2136+
2137+
self.edm = ElucidatedAtomDiffusion(
2138+
self.diffusion_module,
2139+
sigma_data = sigma_data,
2140+
**edm_kwargs
2141+
)
2142+
20942143
# logit heads
20952144

20962145
self.distogram_head = DistogramHead(
@@ -2116,11 +2165,11 @@ def __init__(
21162165
self.loss_confidence_weight = loss_confidence_weight
21172166
self.loss_diffusion_weight = loss_diffusion_weight
21182167

2119-
self.register_buffer('dummy', torch.tensor(0), persistent = False)
2168+
self.register_buffer('zero', torch.tensor(0.), persistent = False)
21202169

21212170
@property
21222171
def device(self):
2123-
return self.dummy.device
2172+
return self.zero.device
21242173

21252174
@typecheck
21262175
def forward(
@@ -2134,6 +2183,8 @@ def forward(
21342183
templates: Float['b t n n dt'],
21352184
template_mask: Bool['b t'],
21362185
num_recycling_steps: int = 1,
2186+
num_sample_steps: int | None = None,
2187+
atom_pos: Float['b m 3'] | None = None,
21372188
distance_labels: Int['b n n'] | None = None,
21382189
pae_labels: Int['b n n'] | None = None,
21392190
pde_labels: Int['b n n'] | None = None,
@@ -2228,23 +2279,52 @@ def forward(
22282279
# determine whether to return loss: a loss is returned if atom positions or any labels are passed in
22292280
# otherwise will sample the atomic coordinates
22302281

2282+
atom_pos_given = exists(atom_pos)
2283+
22312284
labels = (distance_labels, pae_labels, pde_labels, plddt_labels, resolved_labels)
2232-
return_loss = any([*map(exists, labels)])
2285+
has_labels = any([*map(exists, labels)])
2286+
2287+
return_loss = atom_pos_given or has_labels
2288+
2289+
# setup all the data necessary for conditioning the diffusion module
2290+
2291+
diffusion_cond = dict(
2292+
atom_feats = atom_feats,
2293+
atompair_feats = atompair_feats,
2294+
atom_mask = atom_mask,
2295+
mask = mask,
2296+
single_trunk_repr = single,
2297+
single_inputs_repr = single_inputs,
2298+
pairwise_trunk = pairwise,
2299+
pairwise_rel_pos_feats = relative_position_encoding
2300+
)
2301+
2302+
# if neither atom positions nor any labels are passed in, sample a structure and return
22332303

22342304
if not return_loss:
2235-
return torch.randn((*atom_inputs.shape[:2], 3), device = self.device)
2305+
return self.edm.sample(num_sample_steps = num_sample_steps, **diffusion_cond)
2306+
2307+
# otherwise, noise and make it learn to denoise
2308+
2309+
diffusion_loss = self.zero
2310+
2311+
if exists(atom_pos):
2312+
diffusion_loss = self.edm(atom_pos, **diffusion_cond)
22362313

22372314
# calculate all logits and losses
22382315

22392316
ignore = self.ignore_index
22402317

2318+
distogram_loss = self.zero
2319+
22412320
if exists(distance_labels):
22422321
distance_labels = torch.where(pairwise_mask, distance_labels, ignore)
22432322
distogram_logits = self.distogram_head(pairwise)
22442323
distogram_loss = F.cross_entropy(distogram_logits, distance_labels, ignore_index = ignore)
22452324

22462325
loss = (
2247-
distogram_loss * self.loss_distogram_weight
2326+
distogram_loss * self.loss_distogram_weight +
2327+
diffusion_loss * self.loss_diffusion_weight
22482328
)
22492329

22502330
return loss

tests/test_readme.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,7 @@ def test_alphafold3():
256256

257257
msa = torch.randn(2, 7, seq_len, 64)
258258

259+
atom_pos = torch.randn(2, atom_seq_len, 3)
259260
distance_labels = torch.randint(0, 38, (2, seq_len, seq_len))
260261

261262
alphafold3 = Alphafold3(
@@ -274,7 +275,21 @@ def test_alphafold3():
274275
msa = msa,
275276
templates = template_feats,
276277
template_mask = template_mask,
278+
atom_pos = atom_pos,
277279
distance_labels = distance_labels
278280
)
279281

280282
loss.backward()
283+
284+
sampled_atom_pos = alphafold3(
285+
num_sample_steps = 16,
286+
atom_inputs = atom_inputs,
287+
atom_mask = atom_mask,
288+
atompair_feats = atompair_feats,
289+
additional_residue_feats = additional_residue_feats,
290+
msa = msa,
291+
templates = template_feats,
292+
template_mask = template_mask,
293+
)
294+
295+
assert sampled_atom_pos.ndim == 3

0 commit comments

Comments
 (0)