Skip to content

Commit a2dd442

Browse files
authored
Merge pull request #6 from MurrellGroup/consolidation
Consolidation
2 parents 52fed45 + aa72f22 commit a2dd442

File tree

6 files changed

+81
-176
lines changed

6 files changed

+81
-176
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ Mmap = "a63ad114-7e13-5084-954f-fe012c677804"
1515
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
1616
NPZ = "15e1cf62-19b3-5cfa-8e77-841668bca605"
1717
Onion = "fdebf6c2-71da-43a1-b539-c3bc3e09c5c6"
18+
ProtInterop = "b3e4c6a1-2f5a-4d8c-9e7b-1a3c5d9f0e2b"
1819
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
1920
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
2021

src/ESMFold.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@ using JSON
1010
using HuggingFaceApi
1111

1212
include("device_utils.jl")
13-
include("safetensors.jl")
13+
using ProtInterop.ProtSafeTensors
14+
using ProtInterop: compute_plddt, compute_predicted_aligned_error, compute_tm
1415
include("constants.jl")
1516

1617
# GPU support

src/esmfold_full.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ function (m::ESMFoldModel)(
188188
structure[:ptm] = to_device(reshape(collect(ptm_vals), size(ptm_logits, 4)), ptm_logits, eltype(ptm_logits))
189189

190190
structure_update = compute_predicted_aligned_error(ptm_logits; max_bin=31, no_bins=m.distogram_bins)
191-
for (k, v) in structure_update
191+
for (k, v) in pairs(structure_update)
192192
structure[k] = v
193193
end
194194

src/openfold_infer_utils.jl

Lines changed: 16 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -1,49 +1,18 @@
1-
using NNlib
1+
# OpenFold inference utilities.
2+
# Atom14/atom37 data tables and confidence metrics come from ProtInterop.
3+
# This file keeps only ESMFold-specific wrappers.
24

3-
const _restype_atom14_to_atom37 = let
4-
rows = Vector{Vector{Int}}()
5-
for rt in restypes
6-
atom_names = restype_name_to_atom14_names[restype_1to3[rt]]
7-
push!(rows, [name == "" ? 1 : (atom_order[name] + 1) for name in atom_names])
8-
end
9-
push!(rows, fill(1, 14)) # UNK
10-
reduce(vcat, (reshape(r, 1, :) for r in rows))
11-
end
5+
using ProtInterop: OF_RESTYPE_ATOM14_TO_ATOM37, OF_RESTYPE_ATOM37_TO_ATOM14,
6+
OF_RESTYPE_ATOM14_MASK, OF_RESTYPE_ATOM37_MASK
127

13-
const _restype_atom37_to_atom14 = let
14-
rows = Vector{Vector{Int}}()
15-
for rt in restypes
16-
atom_names = restype_name_to_atom14_names[restype_1to3[rt]]
17-
atom_name_to_idx14 = Dict(name => i for (i, name) in enumerate(atom_names) if name != "")
18-
push!(rows, [get(atom_name_to_idx14, name, 1) for name in atom_types])
19-
end
20-
push!(rows, fill(1, length(atom_types))) # UNK
21-
reduce(vcat, (reshape(r, 1, :) for r in rows))
22-
end
8+
# ── Aliases for backward compatibility ──────────────────────────────────────
239

24-
const _restype_atom14_mask = let
25-
rows = Vector{Vector{Float32}}()
26-
for rt in restypes
27-
atom_names = restype_name_to_atom14_names[restype_1to3[rt]]
28-
push!(rows, [name == "" ? 0f0 : 1f0 for name in atom_names])
29-
end
30-
push!(rows, fill(0f0, 14)) # UNK
31-
reduce(vcat, (reshape(r, 1, :) for r in rows))
32-
end
10+
const _restype_atom14_to_atom37 = OF_RESTYPE_ATOM14_TO_ATOM37
11+
const _restype_atom37_to_atom14 = OF_RESTYPE_ATOM37_TO_ATOM14
12+
const _restype_atom14_mask = OF_RESTYPE_ATOM14_MASK
13+
const _restype_atom37_mask = OF_RESTYPE_ATOM37_MASK
3314

34-
const _restype_atom37_mask = let
35-
mask = zeros(Float32, length(restypes), length(atom_types))
36-
for (restype, restype_letter) in enumerate(restypes)
37-
restype_name = restype_1to3[restype_letter]
38-
atom_names = residue_atoms[restype_name]
39-
for atom_name in atom_names
40-
atom_type = atom_order[atom_name] + 1
41-
mask[restype, atom_type] = 1f0
42-
end
43-
end
44-
mask = vcat(mask, zeros(Float32, 1, length(atom_types))) # UNK
45-
mask
46-
end
15+
# ── make_atom14_masks! (ESMFold-specific: mutating API, modifies Dict) ──────
4716

4817
function make_atom14_masks!(protein::AbstractDict)
4918
protein_aatype = protein[:aatype] .+ 1
@@ -67,82 +36,16 @@ function make_atom14_masks!(protein::AbstractDict)
6736
return protein
6837
end
6938

70-
function _calculate_bin_centers(boundaries::AbstractArray)
71-
step = boundaries[2] - boundaries[1]
72-
bin_centers = boundaries .+ step / 2
73-
return vcat(bin_centers, bin_centers[end] + step)
74-
end
39+
# ── _calculate_expected_aligned_error (ESMFold-specific helper) ─────────────
7540

7641
function _calculate_expected_aligned_error(alignment_confidence_breaks, aligned_distance_error_probs)
77-
bin_centers = _calculate_bin_centers(alignment_confidence_breaks)
42+
bin_centers = ProtInterop._calculate_bin_centers(alignment_confidence_breaks)
7843
bview = reshape(bin_centers, ntuple(_ -> 1, ndims(aligned_distance_error_probs) - 1)..., length(bin_centers))
7944
expected = sum(aligned_distance_error_probs .* bview; dims=ndims(aligned_distance_error_probs))
8045
expected = dropdims(expected; dims=ndims(expected))
8146
return expected, bin_centers[end]
8247
end
8348

84-
function compute_predicted_aligned_error(
85-
logits;
86-
max_bin::Int = 31,
87-
no_bins::Int = 64,
88-
)
89-
boundaries_cpu = collect(range(0f0, Float32(max_bin); length=no_bins - 1))
90-
bin_centers_cpu = _calculate_bin_centers(boundaries_cpu)
91-
92-
aligned_confidence_probs = NNlib.softmax(logits; dims=1)
93-
bin_centers = to_device(bin_centers_cpu, logits, Float32)
94-
bview = reshape(bin_centers, length(bin_centers), ntuple(_ -> 1, ndims(aligned_confidence_probs) - 1)...)
95-
expected = sum(aligned_confidence_probs .* bview; dims=1)
96-
expected = dropdims(expected; dims=1)
97-
98-
return Dict(
99-
:aligned_confidence_probs => aligned_confidence_probs,
100-
:predicted_aligned_error => expected,
101-
:max_predicted_aligned_error => bin_centers_cpu[end],
102-
)
103-
end
104-
105-
function compute_tm(
106-
logits;
107-
residue_weights = nothing,
108-
asym_id = nothing,
109-
interface::Bool = false,
110-
max_bin::Int = 31,
111-
no_bins::Int = 64,
112-
eps::Real = 1e-8,
113-
)
114-
if residue_weights === nothing
115-
residue_weights = ones_like(logits, size(logits, 2))
116-
end
117-
118-
boundaries = range(0f0, Float32(max_bin); length=no_bins - 1)
119-
boundaries = to_device(collect(boundaries), logits, Float32)
120-
bin_centers = _calculate_bin_centers(boundaries)
121-
122-
clipped_n = max(sum(residue_weights), 19)
123-
d0 = 1.24f0 * (clipped_n - 15)^(1f0 / 3f0) - 1.8f0
124-
125-
probs = NNlib.softmax(logits; dims=1)
126-
tm_per_bin = 1f0 ./ (1f0 .+ (bin_centers .^ 2) ./ (d0 ^ 2))
127-
tm_view = reshape(tm_per_bin, length(tm_per_bin), ntuple(_ -> 1, ndims(probs) - 1)...)
128-
predicted_tm_term = sum(probs .* tm_view; dims=1)
129-
predicted_tm_term = dropdims(predicted_tm_term; dims=1)
130-
131-
n = size(predicted_tm_term, 1)
132-
pair_mask = ones_like(predicted_tm_term, Int, n, n)
133-
if interface && asym_id !== nothing
134-
pair_mask .= (reshape(asym_id, :, 1) .!= reshape(asym_id, 1, :))
135-
end
136-
137-
predicted_tm_term = predicted_tm_term .* pair_mask
138-
139-
pair_residue_weights = pair_mask .* (reshape(residue_weights, :, 1) .* reshape(residue_weights, 1, :))
140-
denom = eps .+ sum(pair_residue_weights; dims=2)
141-
normed_residue_mask = pair_residue_weights ./ denom
142-
per_alignment = sum(predicted_tm_term .* normed_residue_mask; dims=2)
143-
per_alignment = dropdims(per_alignment; dims=2)
144-
145-
weighted = per_alignment .* residue_weights
146-
max_idx = argmax(weighted)
147-
return per_alignment[max_idx]
148-
end
49+
# Confidence metrics (compute_plddt, compute_predicted_aligned_error, compute_tm)
50+
# are now imported by name from ProtInterop (`using ProtInterop: ...`) in src/ESMFold.jl.
51+
# ESMFold re-exports them.

src/pipeline.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,7 @@ function run_heads(model::ESMFoldModel, structure::AbstractDict, inputs::NamedTu
342342

343343
# Predicted aligned error
344344
structure_update = compute_predicted_aligned_error(ptm_logits; max_bin=31, no_bins=model.distogram_bins)
345-
for (k, v) in structure_update
345+
for (k, v) in pairs(structure_update)
346346
output[k] = v
347347
end
348348

0 commit comments

Comments
 (0)