
Commit fe07e85

Port ESMFold folding stack

1 parent: c266e7e

22 files changed: +4017, -22 lines

Project.toml

Lines changed: 1 addition & 0 deletions
@@ -14,6 +14,7 @@ Mmap = "a63ad114-7e13-5084-954f-fe012c677804"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 NPZ = "15e1cf62-19b3-5cfa-8e77-841668bca605"
 Onion = "fdebf6c2-71da-43a1-b539-c3bc3e09c5c6"
+Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
 SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

docs/AGENT_NOTES.md

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
# Agent Notes

- Before running any command or editing any file in this project, write a short two-sentence summary of what I’m about to do so the user can follow along, and then proceed without waiting for confirmation.
Lines changed: 199 additions & 0 deletions
@@ -0,0 +1,199 @@
import argparse
import numpy as np
import torch

from export_esmfold_full_from_safetensors import (
    SafeTensorsReader,
    build_model,
    load_esm2_weights,
    load_rest_weights,
    compute_language_model_representations,
    infer_config,
)

RESTYPES = [
    "A","R","N","D","C","Q","E","G","H","I","L","K","M","F","P","S","T","W","Y","V",
]
RESTYPES_WITH_X = RESTYPES + ["X"]
RESTYPE_ORDER_WITH_X = {aa: i for i, aa in enumerate(RESTYPES_WITH_X)}


def sequence_to_af2_indices(seq: str):
    return [RESTYPE_ORDER_WITH_X.get(ch, RESTYPE_ORDER_WITH_X["X"]) for ch in seq]


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--safetensors", required=True)
    parser.add_argument("--sequence", default="ELLKKLLEELKG")
    parser.add_argument("--output", default="esmfold_block0_debug.npz")
    args = parser.parse_args()

    torch.set_grad_enabled(False)

    # Build the model and load weights from the safetensors checkpoint.
    from pathlib import Path
    with SafeTensorsReader(Path(args.safetensors)) as reader:
        model = build_model(reader, use_esm_attn_map=False)
        load_esm2_weights(model.esm, reader)
        load_rest_weights(model, reader)

    # Put every submodule in eval mode.
    model.esm.eval()
    model.esm_s_mlp.eval()
    model.embedding.eval()
    model.trunk.eval()
    model.distogram_head.eval()
    model.ptm_head.eval()
    model.lm_head.eval()
    model.lddt_head.eval()

    # build s_s_0, s_z_0
    seq = args.sequence
    aa = torch.tensor([sequence_to_af2_indices(seq)], dtype=torch.long)
    mask = torch.ones_like(aa)

    esmaa = model.af2_to_esm[(aa + 1).masked_fill(mask != 1, 0)]
    esm_s, _ = compute_language_model_representations(model, esmaa, use_esm_attn_map=False)
    esm_s = esm_s.to(model.esm_s_combine.dtype).detach()
    esm_s = (torch.softmax(model.esm_s_combine, 0).unsqueeze(0) @ esm_s).squeeze(2)
    s_s_0 = model.esm_s_mlp(esm_s) + model.embedding(aa)
    B, L, _ = s_s_0.shape
    s_z_0 = s_s_0.new_zeros(B, L, L, model.cfg["c_z"])

    # Manually trace the first folding-trunk block, capturing intermediates.
    block = model.trunk.blocks[0]
    tri_mask = mask.unsqueeze(2) * mask.unsqueeze(1)

    bias = block.pair_to_sequence(s_z_0)
    seq_ln = block.layernorm_1(s_s_0)
    proj_out = block.seq_attention.proj(seq_ln)
    B, L, C3 = proj_out.shape
    H = block.seq_attention.num_heads
    head_width = block.seq_attention.head_width
    t = proj_out.view(B, L, H, 3 * head_width)
    q = t[..., :head_width]
    k = t[..., head_width:2 * head_width]
    v = t[..., 2 * head_width:3 * head_width]
    seq_attn_out, attn = block.seq_attention(seq_ln, mask=mask, bias=bias)
    g_proj_out = block.seq_attention.g_proj(seq_ln)
    q_py = q.permute(0, 2, 1, 3)
    k_py = k.permute(0, 2, 1, 3)
    logits = torch.einsum("...qc,...kc->...qk", q_py, k_py)
    seq_state_attn = s_s_0 + seq_attn_out
    seq_state_mlp = block.mlp_seq(seq_state_attn)

    # manual attention to capture pre-o_proj outputs
    t_attn = block.seq_attention.proj(seq_ln)
    t_attn = t_attn.reshape(B, L, H, 3 * head_width).permute(0, 2, 1, 3)
    q2 = t_attn[..., :head_width]
    k2 = t_attn[..., head_width:2 * head_width]
    v2 = t_attn[..., 2 * head_width:3 * head_width]
    q2 = q2 * block.seq_attention.rescale_factor
    a2 = torch.einsum("...qc,...kc->...qk", q2, k2)
    a2 = a2 + bias.permute(0, 3, 1, 2)
    if mask is not None:
        a2 = a2.masked_fill(mask[:, None, None, :].expand_as(a2) == 0, float("-inf"))
    a2 = torch.softmax(a2, dim=-1)
    y2 = torch.einsum("...hqk,...hkc->...qhc", a2, v2)
    y2 = y2.reshape(B, L, H * head_width)
    if block.seq_attention.gated:
        y2_gated = block.seq_attention.g_proj(seq_ln).sigmoid() * y2
    else:
        y2_gated = y2

    pair_state = s_z_0 + block.sequence_to_pair(seq_state_mlp)

    tri_mul_out = block.tri_mul_out(pair_state, mask=tri_mask)
    pair_state = pair_state + tri_mul_out

    tri_mul_in = block.tri_mul_in(pair_state, mask=tri_mask)
    pair_state = pair_state + tri_mul_in

    pair_state_before_start = pair_state
    tri_att_start = block.tri_att_start(pair_state_before_start, mask=tri_mask, chunk_size=None)
    pair_state = pair_state_before_start + tri_att_start

    pair_state_before_end = pair_state
    tri_att_end = block.tri_att_end(pair_state_before_end, mask=tri_mask, chunk_size=None)
    pair_state = pair_state_before_end + tri_att_end

    # Triangle attention internals (start/end) to pinpoint mismatches
    from openfold.utils.tensor_utils import permute_final_dims, flatten_final_dims

    def tri_attn_internals(ta, x, mask):
        if not ta.starting:
            x = x.transpose(-2, -3)
            mask = mask.transpose(-1, -2)
        x = ta.layer_norm(x)
        mask_bias = (ta.inf * (mask - 1))[..., :, None, None, :]
        triangle_bias = permute_final_dims(ta.linear(x), (2, 0, 1))
        triangle_bias = triangle_bias.unsqueeze(-4)
        q, k, v = ta.mha._prep_qkv(x, x, apply_scale=True)
        logits = torch.matmul(q, k.transpose(-1, -2))
        logits = logits + mask_bias + triangle_bias
        attn = torch.softmax(logits, dim=-1)
        o = torch.matmul(attn, v)  # [*, H, Q, C]
        o = o.transpose(-2, -3)  # [*, Q, H, C]
        if ta.mha.linear_g is not None:
            g = torch.sigmoid(ta.mha.linear_g(x))
            g = g.view(g.shape[:-1] + (ta.mha.no_heads, -1))
            o = o * g
        o_flat = flatten_final_dims(o, 2)  # [*, Q, H*C]
        return logits, attn, o_flat, mask_bias, triangle_bias, x, q, k, v

    tri_att_start_logits, tri_att_start_attn, tri_att_start_pre_o, tri_att_start_mask_bias, tri_att_start_triangle_bias, tri_att_start_ln, tri_att_start_q, tri_att_start_k, tri_att_start_v = tri_attn_internals(
        block.tri_att_start, pair_state_before_start, tri_mask
    )
    tri_att_end_logits, tri_att_end_attn, tri_att_end_pre_o, tri_att_end_mask_bias, tri_att_end_triangle_bias, tri_att_end_ln, tri_att_end_q, tri_att_end_k, tri_att_end_v = tri_attn_internals(
        block.tri_att_end, pair_state_before_end, tri_mask
    )

    pair_state_mlp = block.mlp_pair(pair_state)

    export = {
        "s_s_0": s_s_0.cpu().numpy(),
        "s_z_0": s_z_0.cpu().numpy(),
        "bias": bias.cpu().numpy(),
        "seq_ln": seq_ln.cpu().numpy(),
        "proj_out": proj_out.cpu().numpy(),
        "seq_attn_out": seq_attn_out.cpu().numpy(),
        "seq_attn_pre_o_proj": y2_gated.cpu().numpy(),
        "seq_attn_pre_gate": y2.cpu().numpy(),
        "g_proj_out": g_proj_out.cpu().numpy(),
        "logits": logits.cpu().numpy(),
        "q": q.cpu().numpy(),
        "k": k.cpu().numpy(),
        "v": v.cpu().numpy(),
        "attn": attn.cpu().numpy(),
        "seq_state_attn": seq_state_attn.cpu().numpy(),
        "seq_state_mlp": seq_state_mlp.cpu().numpy(),
        "pair_state_seq2pair": (s_z_0 + block.sequence_to_pair(seq_state_mlp)).cpu().numpy(),
        "tri_mul_out": tri_mul_out.cpu().numpy(),
        "tri_mul_in": tri_mul_in.cpu().numpy(),
        "pair_state_before_start": pair_state_before_start.cpu().numpy(),
        "pair_state_before_end": pair_state_before_end.cpu().numpy(),
        "tri_att_start": tri_att_start.cpu().numpy(),
        "tri_att_end": tri_att_end.cpu().numpy(),
        "tri_att_start_logits": tri_att_start_logits.cpu().numpy(),
        "tri_att_start_attn": tri_att_start_attn.cpu().numpy(),
        "tri_att_start_pre_o": tri_att_start_pre_o.cpu().numpy(),
        "tri_att_start_mask_bias": tri_att_start_mask_bias.cpu().numpy(),
        "tri_att_start_triangle_bias": tri_att_start_triangle_bias.cpu().numpy(),
        "tri_att_start_ln": tri_att_start_ln.cpu().numpy(),
        "tri_att_start_q": tri_att_start_q.cpu().numpy(),
        "tri_att_start_k": tri_att_start_k.cpu().numpy(),
        "tri_att_start_v": tri_att_start_v.cpu().numpy(),
        "tri_att_end_logits": tri_att_end_logits.cpu().numpy(),
        "tri_att_end_attn": tri_att_end_attn.cpu().numpy(),
        "tri_att_end_pre_o": tri_att_end_pre_o.cpu().numpy(),
        "tri_att_end_mask_bias": tri_att_end_mask_bias.cpu().numpy(),
        "tri_att_end_triangle_bias": tri_att_end_triangle_bias.cpu().numpy(),
        "tri_att_end_ln": tri_att_end_ln.cpu().numpy(),
        "tri_att_end_q": tri_att_end_q.cpu().numpy(),
        "tri_att_end_k": tri_att_end_k.cpu().numpy(),
        "tri_att_end_v": tri_att_end_v.cpu().numpy(),
        "pair_state_final": pair_state_mlp.cpu().numpy(),
    }
    np.savez(args.output, **export)


if __name__ == "__main__":
    main()
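
This script is the PyTorch reference side of the port: it dumps every intermediate of folding-trunk block 0 to an NPZ so the ported stack can be checked tensor by tensor. Below is a minimal sketch of how such a dump might be consumed on the comparison side; the candidate file name and the tolerance are illustrative assumptions, not part of this commit.

# Hypothetical comparison helper: load the reference NPZ written by the debug
# exporter and an NPZ produced by the ported implementation, then report the
# largest absolute deviation per captured tensor. File names are placeholders.
import numpy as np

ref = np.load("esmfold_block0_debug.npz")    # PyTorch reference dump
cand = np.load("esmfold_block0_port.npz")    # dump from the ported stack (assumed name)

for key in ref.files:
    if key not in cand.files:
        print(f"{key}: missing in candidate dump")
        continue
    a, b = ref[key], cand[key]
    if a.shape != b.shape:
        print(f"{key}: shape mismatch {a.shape} vs {b.shape}")
        continue
    max_abs = np.max(np.abs(a - b))
    status = "OK" if max_abs < 1e-4 else "MISMATCH"  # tolerance is an assumption
    print(f"{key}: max |diff| = {max_abs:.3e} [{status}]")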
