Skip to content

Commit 5d51920

Browse files
committed
Fix structure/lddt parity and add structure parity scripts
1 parent 9b9c1cc commit 5d51920

10 files changed

+608
-11
lines changed

docs/AGENT_NOTES.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
11
# Agent Notes
22

33
- Before running any command or editing any file in this project, write a short two-sentence summary of what I’m about to do so the user can follow along, and then proceed without waiting for confirmation.
4+
5+
- Current issue: StructureModule parity now matches through intermediates; `positions`, `frames`, `angles`, and `states` are within ~1e-5 of Python after fixing AngleResnet reshape (C-order) and rigid rotation layout. The remaining large mismatch is isolated to `lddt_head`/`plddt` (and `categorical_lddt`), while other heads are close; next step is to inspect and instrument the LDDT head path in Julia vs Python.
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
using Pkg
2+
Pkg.activate("/Users/benmurrell/JuliaM3/juliaESM"; io=devnull)
3+
4+
using NPZ
5+
using Statistics
6+
using ESMEmbed
7+
8+
ref = NPZ.npzread("/Users/benmurrell/JuliaM3/juliaESM/residue_constants_ref.npz")
9+
10+
function diff_stats(a, b)
11+
max_abs = maximum(abs.(a .- b))
12+
mean_abs = mean(abs.(a .- b))
13+
return max_abs, mean_abs
14+
end
15+
16+
function report(name, a, b)
17+
max_abs, mean_abs = diff_stats(Float32.(a), Float32.(b))
18+
println(name, " max_abs=", max_abs, " mean_abs=", mean_abs)
19+
end
20+
21+
report(
22+
"restype_rigid_group_default_frame",
23+
ESMEmbed.restype_rigid_group_default_frame,
24+
ref["restype_rigid_group_default_frame"],
25+
)
26+
report(
27+
"restype_atom14_to_rigid_group",
28+
ESMEmbed.restype_atom14_to_rigid_group,
29+
ref["restype_atom14_to_rigid_group"],
30+
)
31+
report(
32+
"restype_atom14_mask",
33+
ESMEmbed.restype_atom14_mask,
34+
ref["restype_atom14_mask"],
35+
)
36+
report(
37+
"restype_atom14_rigid_group_positions",
38+
ESMEmbed.restype_atom14_rigid_group_positions,
39+
ref["restype_atom14_rigid_group_positions"],
40+
)
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
using Pkg
2+
Pkg.activate("/Users/benmurrell/JuliaM3/juliaESM"; io=devnull)
3+
4+
using NPZ
5+
using Statistics
6+
using ESMEmbed
7+
8+
ref = NPZ.npzread("/Users/benmurrell/JuliaM3/juliaESM/esmfold_structure_debug.npz")
9+
10+
path = "weights/esm.safetensors"
11+
reader = ESMEmbed.SafeTensors.Reader(path)
12+
(
13+
num_layers,
14+
embed_dim,
15+
attention_heads,
16+
c_s,
17+
c_z,
18+
sequence_head_width,
19+
pairwise_head_width,
20+
position_bins,
21+
num_blocks,
22+
lddt_head_hid_dim,
23+
) = ESMEmbed._infer_esmfold_full_config(reader)
24+
25+
alphabet = ESMEmbed.Alphabet_from_architecture("ESM-1b")
26+
esm = ESMEmbed.ESM2(
27+
num_layers,
28+
embed_dim,
29+
attention_heads;
30+
alphabet = alphabet,
31+
token_dropout = true,
32+
)
33+
34+
trunk_cfg = ESMEmbed.FoldingTrunkConfig(
35+
num_blocks,
36+
c_s,
37+
c_z,
38+
sequence_head_width,
39+
pairwise_head_width,
40+
position_bins,
41+
0f0,
42+
0f0,
43+
false,
44+
4,
45+
nothing,
46+
ESMEmbed.StructureModuleConfig(),
47+
)
48+
49+
cfg = ESMEmbed.ESMFoldConfig(; trunk=trunk_cfg, lddt_head_hid_dim=lddt_head_hid_dim, use_esm_attn_map=false)
50+
model = ESMEmbed.ESMFold(esm; cfg=cfg)
51+
ESMEmbed.load_esmfold_safetensors!(model, reader)
52+
53+
single = Float32.(ref["single"])
54+
pair = Float32.(ref["pair"])
55+
aatype = Int.(ref["aatype"])
56+
mask = Float32.(ref["mask"])
57+
58+
sm = model.trunk.structure_module
59+
out = sm(Dict(:single => single, :pair => pair), aatype, mask)
60+
61+
function diff_stats(a, b)
62+
max_abs = maximum(abs.(a .- b))
63+
mean_abs = mean(abs.(a .- b))
64+
return max_abs, mean_abs
65+
end
66+
67+
keys = (
68+
:frames,
69+
:sidechain_frames,
70+
:unnormalized_angles,
71+
:angles,
72+
:positions,
73+
:states,
74+
)
75+
76+
for k in keys
77+
ref_key = String(k)
78+
haskey(ref, ref_key) || continue
79+
a = Float32.(Array(out[k]))
80+
b = Float32.(ref[ref_key])
81+
max_abs, mean_abs = diff_stats(a, b)
82+
println("sm_", ref_key, " max_abs=", max_abs, " mean_abs=", mean_abs)
83+
end
Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
using Pkg
2+
Pkg.activate("/Users/benmurrell/JuliaM3/juliaESM"; io=devnull)
3+
4+
using NPZ
5+
using Statistics
6+
using ESMEmbed
7+
using NNlib
8+
9+
ref = NPZ.npzread("/Users/benmurrell/JuliaM3/juliaESM/esmfold_structure_intermediate.npz")
10+
11+
path = "weights/esm.safetensors"
12+
reader = ESMEmbed.SafeTensors.Reader(path)
13+
(
14+
num_layers,
15+
embed_dim,
16+
attention_heads,
17+
c_s,
18+
c_z,
19+
sequence_head_width,
20+
pairwise_head_width,
21+
position_bins,
22+
num_blocks,
23+
lddt_head_hid_dim,
24+
) = ESMEmbed._infer_esmfold_full_config(reader)
25+
26+
alphabet = ESMEmbed.Alphabet_from_architecture("ESM-1b")
27+
esm = ESMEmbed.ESM2(
28+
num_layers,
29+
embed_dim,
30+
attention_heads;
31+
alphabet = alphabet,
32+
token_dropout = true,
33+
)
34+
35+
trunk_cfg = ESMEmbed.FoldingTrunkConfig(
36+
num_blocks,
37+
c_s,
38+
c_z,
39+
sequence_head_width,
40+
pairwise_head_width,
41+
position_bins,
42+
0f0,
43+
0f0,
44+
false,
45+
4,
46+
nothing,
47+
ESMEmbed.StructureModuleConfig(),
48+
)
49+
50+
cfg = ESMEmbed.ESMFoldConfig(; trunk=trunk_cfg, lddt_head_hid_dim=lddt_head_hid_dim, use_esm_attn_map=false)
51+
model = ESMEmbed.ESMFold(esm; cfg=cfg)
52+
ESMEmbed.load_esmfold_safetensors!(model, reader)
53+
54+
single = Float32.(ref["single"])
55+
pair = Float32.(ref["pair"])
56+
aatype = Int.(ref["aatype"])
57+
mask = Float32.(ref["mask"])
58+
59+
sm = model.trunk.structure_module
60+
61+
s_ln_s = sm.layer_norm_s(single)
62+
z_ln = sm.layer_norm_z(pair)
63+
s_initial = s_ln_s
64+
s_in = sm.linear_in(s_ln_s)
65+
66+
rigids = ESMEmbed.rigid_identity(size(s_in)[1:end-1], s_in; fmt=:quat)
67+
ipa_out = sm.ipa(s_in, z_ln, rigids, mask)
68+
s_after_ipa = s_in .+ ipa_out
69+
s_after_ipa = sm.ipa_dropout(s_after_ipa)
70+
s_after_ln = sm.layer_norm_ipa(s_after_ipa)
71+
s_after_transition = sm.transition(s_after_ln)
72+
73+
bb_update = sm.bb_update(s_after_transition)
74+
rigids1 = ESMEmbed.compose_q_update_vec(rigids, bb_update)
75+
76+
backb_to_global = ESMEmbed.Rigid(
77+
ESMEmbed.Rotation(rot_mats=ESMEmbed.get_rot_mats(rigids1.rots)),
78+
rigids1.trans,
79+
)
80+
backb_to_global = ESMEmbed.scale_translation(backb_to_global, sm.cfg.trans_scale_factor)
81+
82+
unnormalized_angles, angles = sm.angle_resnet(s_after_transition, s_initial)
83+
84+
default_frames, group_idx, atom_mask, lit_positions = ESMEmbed._init_residue_constants!(sm, angles)
85+
all_frames_to_global = ESMEmbed.torsion_angles_to_frames(backb_to_global, angles, aatype, default_frames)
86+
pred_xyz = ESMEmbed.frames_and_literature_positions_to_atom14_pos(
87+
all_frames_to_global,
88+
aatype,
89+
default_frames,
90+
group_idx,
91+
atom_mask,
92+
lit_positions,
93+
)
94+
95+
scaled_rigids = ESMEmbed.scale_translation(rigids1, sm.cfg.trans_scale_factor)
96+
97+
function diff_stats(a, b)
98+
max_abs = maximum(abs.(a .- b))
99+
mean_abs = mean(abs.(a .- b))
100+
return max_abs, mean_abs
101+
end
102+
103+
function report(name, a, b)
104+
max_abs, mean_abs = diff_stats(Float32.(a), Float32.(b))
105+
println(name, " max_abs=", max_abs, " mean_abs=", mean_abs)
106+
end
107+
108+
report("s_ln_s", s_ln_s, ref["s_ln_s"])
109+
report("z_ln", z_ln, ref["z_ln"])
110+
report("s_initial", s_initial, ref["s_initial"])
111+
report("s_in", s_in, ref["s_in"])
112+
report("ipa_out", ipa_out, ref["ipa_out"])
113+
report("s_after_ipa", s_after_ipa, ref["s_after_ipa"])
114+
report("s_after_ln", s_after_ln, ref["s_after_ln"])
115+
report("s_after_transition", s_after_transition, ref["s_after_transition"])
116+
report("bb_update", bb_update, ref["bb_update"])
117+
report("rigids_t7", ESMEmbed.to_tensor_7(rigids1), ref["rigids_t7"])
118+
report("backb_t7", ESMEmbed.to_tensor_7(backb_to_global), ref["backb_t7"])
119+
report("unnormalized_angles", unnormalized_angles, ref["unnormalized_angles"])
120+
report("angles", angles, ref["angles"])
121+
report("frames", ESMEmbed.to_tensor_7(scaled_rigids), ref["frames"])
122+
report("sidechain_frames", ESMEmbed.to_tensor_4x4(all_frames_to_global), ref["sidechain_frames"])
123+
report("positions", pred_xyz, ref["positions"])
124+
report("states", s_after_transition, ref["states"])
125+
126+
# --- torsion/frame intermediates (match OpenFold feats.py) ---
127+
default_frames, _, _, _ = ESMEmbed._init_residue_constants!(sm, angles)
128+
idx = aatype .+ 1
129+
df = permutedims(default_frames, (2, 3, 4, 1))
130+
df_sel = NNlib.gather(df, idx)
131+
default_4x4 = permutedims(df_sel, (4, 5, 1, 2, 3))
132+
133+
bb_shape = (size(angles)[1:end-2]..., 1, 2)
134+
bb_rot = ESMEmbed.zeros_like(angles, eltype(angles), bb_shape...)
135+
ESMEmbed._view_last2(bb_rot, 1, 2) .= 1
136+
alpha = cat(bb_rot, angles; dims=ndims(angles) - 1)
137+
138+
all_rots = ESMEmbed.zeros_like(angles, eltype(angles), size(default_4x4)...)
139+
ESMEmbed._view_last2(all_rots, 1, 1) .= 1
140+
ESMEmbed._view_last2(all_rots, 2, 2) .= ESMEmbed._view_last1(alpha, 2)
141+
ESMEmbed._view_last2(all_rots, 2, 3) .= -ESMEmbed._view_last1(alpha, 1)
142+
ESMEmbed._view_last2(all_rots, 3, 2) .= ESMEmbed._view_last1(alpha, 1)
143+
ESMEmbed._view_last2(all_rots, 3, 3) .= ESMEmbed._view_last1(alpha, 2)
144+
145+
default_r = ESMEmbed.rigid_from_tensor_4x4(default_4x4)
146+
all_rots_r = ESMEmbed.rigid_from_tensor_4x4(all_rots)
147+
all_frames = ESMEmbed.compose(default_r, all_rots_r)
148+
149+
chi2_frame_to_frame = ESMEmbed.rigid_index(all_frames, Colon(), Colon(), 6)
150+
chi3_frame_to_frame = ESMEmbed.rigid_index(all_frames, Colon(), Colon(), 7)
151+
chi4_frame_to_frame = ESMEmbed.rigid_index(all_frames, Colon(), Colon(), 8)
152+
153+
chi1_frame_to_bb = ESMEmbed.rigid_index(all_frames, Colon(), Colon(), 5)
154+
chi2_frame_to_bb = ESMEmbed.compose(chi1_frame_to_bb, chi2_frame_to_frame)
155+
chi3_frame_to_bb = ESMEmbed.compose(chi2_frame_to_bb, chi3_frame_to_frame)
156+
chi4_frame_to_bb = ESMEmbed.compose(chi3_frame_to_bb, chi4_frame_to_frame)
157+
158+
rot = ESMEmbed.get_rot_mats(all_frames.rots)
159+
trans = all_frames.trans
160+
rot_first = rot[:, :, 1:5, :, :]
161+
trans_first = trans[:, :, 1:5, :]
162+
rot_chi2 = reshape(ESMEmbed.get_rot_mats(chi2_frame_to_bb.rots), size(rot, 1), size(rot, 2), 1, 3, 3)
163+
rot_chi3 = reshape(ESMEmbed.get_rot_mats(chi3_frame_to_bb.rots), size(rot, 1), size(rot, 2), 1, 3, 3)
164+
rot_chi4 = reshape(ESMEmbed.get_rot_mats(chi4_frame_to_bb.rots), size(rot, 1), size(rot, 2), 1, 3, 3)
165+
trans_chi2 = reshape(chi2_frame_to_bb.trans, size(trans, 1), size(trans, 2), 1, 3)
166+
trans_chi3 = reshape(chi3_frame_to_bb.trans, size(trans, 1), size(trans, 2), 1, 3)
167+
trans_chi4 = reshape(chi4_frame_to_bb.trans, size(trans, 1), size(trans, 2), 1, 3)
168+
rot_new = cat(rot_first, rot_chi2, rot_chi3, rot_chi4; dims=3)
169+
trans_new = cat(trans_first, trans_chi2, trans_chi3, trans_chi4; dims=3)
170+
all_frames_to_bb = ESMEmbed.Rigid(ESMEmbed.Rotation(rot_mats=rot_new), trans_new)
171+
all_frames_to_global2 = ESMEmbed.compose(backb_to_global, all_frames_to_bb)
172+
173+
report("backb_rotmats", ESMEmbed.get_rot_mats(backb_to_global.rots), ref["backb_rotmats"])
174+
report("backb_trans", backb_to_global.trans, ref["backb_trans"])
175+
report("default_4x4", default_4x4, ref["default_4x4"])
176+
report("all_rots_4x4", all_rots, ref["all_rots_4x4"])
177+
report("all_frames_4x4", ESMEmbed.to_tensor_4x4(all_frames), ref["all_frames_4x4"])
178+
report("all_frames_to_bb_4x4", ESMEmbed.to_tensor_4x4(all_frames_to_bb), ref["all_frames_to_bb_4x4"])
179+
report("all_frames_to_global_4x4", ESMEmbed.to_tensor_4x4(all_frames_to_global2), ref["all_frames_to_global_4x4"])
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
import argparse
2+
from pathlib import Path
3+
4+
import numpy as np
5+
import torch
6+
7+
from export_esmfold_full_from_safetensors import (
8+
SafeTensorsReader,
9+
build_model,
10+
load_esm2_weights,
11+
load_rest_weights,
12+
infer,
13+
batch_encode_sequences,
14+
)
15+
16+
17+
def main():
18+
parser = argparse.ArgumentParser()
19+
parser.add_argument("--safetensors", required=True)
20+
parser.add_argument("--sequence", default="ELLKKLLEELKG")
21+
parser.add_argument("--output", default="esmfold_structure_debug.npz")
22+
parser.add_argument("--num-recycles", type=int, default=0)
23+
args = parser.parse_args()
24+
25+
torch.set_grad_enabled(False)
26+
27+
with SafeTensorsReader(Path(args.safetensors)) as reader:
28+
model = build_model(reader, use_esm_attn_map=False)
29+
load_esm2_weights(model.esm, reader)
30+
load_rest_weights(model, reader)
31+
32+
model.esm.eval()
33+
model.esm_s_mlp.eval()
34+
model.embedding.eval()
35+
model.trunk.eval()
36+
model.distogram_head.eval()
37+
model.ptm_head.eval()
38+
model.lm_head.eval()
39+
model.lddt_head.eval()
40+
41+
# Run full trunk once to get s_s / s_z
42+
output = infer(model, args.sequence, num_recycles=args.num_recycles)
43+
s_s = output["s_s"]
44+
s_z = output["s_z"]
45+
46+
# Rebuild mask/residx consistent with infer
47+
aatype, mask, residx, _linker_mask, _chain_index = batch_encode_sequences(
48+
[args.sequence], 512, "G" * 25
49+
)
50+
aatype = aatype.to(s_s.device)
51+
mask = mask.to(s_s.device)
52+
residx = residx.to(s_s.device)
53+
54+
# Inputs to structure module
55+
single = model.trunk.trunk2sm_s(s_s)
56+
pair = model.trunk.trunk2sm_z(s_z)
57+
58+
sm_out = model.trunk.structure_module(
59+
{"single": single, "pair": pair},
60+
aatype,
61+
mask.float(),
62+
)
63+
64+
export = {
65+
"single": single.detach().cpu().numpy(),
66+
"pair": pair.detach().cpu().numpy(),
67+
"aatype": aatype.detach().cpu().numpy(),
68+
"mask": mask.detach().cpu().numpy(),
69+
"residx": residx.detach().cpu().numpy(),
70+
"s_s": s_s.detach().cpu().numpy(),
71+
"s_z": s_z.detach().cpu().numpy(),
72+
}
73+
for key in [
74+
"frames",
75+
"sidechain_frames",
76+
"unnormalized_angles",
77+
"angles",
78+
"positions",
79+
"states",
80+
]:
81+
export[key] = sm_out[key].detach().cpu().numpy()
82+
83+
np.savez(args.output, **export)
84+
85+
86+
if __name__ == "__main__":
87+
main()

0 commit comments

Comments
 (0)