Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ Mmap = "a63ad114-7e13-5084-954f-fe012c677804"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
NPZ = "15e1cf62-19b3-5cfa-8e77-841668bca605"
Onion = "fdebf6c2-71da-43a1-b539-c3bc3e09c5c6"
ProtInterop = "b3e4c6a1-2f5a-4d8c-9e7b-1a3c5d9f0e2b"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

Expand Down
3 changes: 2 additions & 1 deletion src/ESMFold.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ using JSON
using HuggingFaceApi

include("device_utils.jl")
include("safetensors.jl")
using ProtInterop.ProtSafeTensors
using ProtInterop: compute_plddt, compute_predicted_aligned_error, compute_tm
include("constants.jl")

# GPU support
Expand Down
2 changes: 1 addition & 1 deletion src/esmfold_full.jl
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ function (m::ESMFoldModel)(
structure[:ptm] = to_device(reshape(collect(ptm_vals), size(ptm_logits, 4)), ptm_logits, eltype(ptm_logits))

structure_update = compute_predicted_aligned_error(ptm_logits; max_bin=31, no_bins=m.distogram_bins)
for (k, v) in structure_update
for (k, v) in pairs(structure_update)
structure[k] = v
end

Expand Down
129 changes: 16 additions & 113 deletions src/openfold_infer_utils.jl
Original file line number Diff line number Diff line change
@@ -1,49 +1,18 @@
using NNlib
# OpenFold inference utilities.
# Atom14/atom37 data tables and confidence metrics come from ProtInterop.
# This file keeps only ESMFold-specific wrappers.

const _restype_atom14_to_atom37 = let
rows = Vector{Vector{Int}}()
for rt in restypes
atom_names = restype_name_to_atom14_names[restype_1to3[rt]]
push!(rows, [name == "" ? 1 : (atom_order[name] + 1) for name in atom_names])
end
push!(rows, fill(1, 14)) # UNK
reduce(vcat, (reshape(r, 1, :) for r in rows))
end
using ProtInterop: OF_RESTYPE_ATOM14_TO_ATOM37, OF_RESTYPE_ATOM37_TO_ATOM14,
OF_RESTYPE_ATOM14_MASK, OF_RESTYPE_ATOM37_MASK

const _restype_atom37_to_atom14 = let
rows = Vector{Vector{Int}}()
for rt in restypes
atom_names = restype_name_to_atom14_names[restype_1to3[rt]]
atom_name_to_idx14 = Dict(name => i for (i, name) in enumerate(atom_names) if name != "")
push!(rows, [get(atom_name_to_idx14, name, 1) for name in atom_types])
end
push!(rows, fill(1, length(atom_types))) # UNK
reduce(vcat, (reshape(r, 1, :) for r in rows))
end
# ── Aliases for backward compatibility ──────────────────────────────────────

const _restype_atom14_mask = let
rows = Vector{Vector{Float32}}()
for rt in restypes
atom_names = restype_name_to_atom14_names[restype_1to3[rt]]
push!(rows, [name == "" ? 0f0 : 1f0 for name in atom_names])
end
push!(rows, fill(0f0, 14)) # UNK
reduce(vcat, (reshape(r, 1, :) for r in rows))
end
const _restype_atom14_to_atom37 = OF_RESTYPE_ATOM14_TO_ATOM37
const _restype_atom37_to_atom14 = OF_RESTYPE_ATOM37_TO_ATOM14
const _restype_atom14_mask = OF_RESTYPE_ATOM14_MASK
const _restype_atom37_mask = OF_RESTYPE_ATOM37_MASK

const _restype_atom37_mask = let
mask = zeros(Float32, length(restypes), length(atom_types))
for (restype, restype_letter) in enumerate(restypes)
restype_name = restype_1to3[restype_letter]
atom_names = residue_atoms[restype_name]
for atom_name in atom_names
atom_type = atom_order[atom_name] + 1
mask[restype, atom_type] = 1f0
end
end
mask = vcat(mask, zeros(Float32, 1, length(atom_types))) # UNK
mask
end
# ── make_atom14_masks! (ESMFold-specific: mutating API, modifies Dict) ──────

function make_atom14_masks!(protein::AbstractDict)
protein_aatype = protein[:aatype] .+ 1
Expand All @@ -67,82 +36,16 @@ function make_atom14_masks!(protein::AbstractDict)
return protein
end

function _calculate_bin_centers(boundaries::AbstractArray)
step = boundaries[2] - boundaries[1]
bin_centers = boundaries .+ step / 2
return vcat(bin_centers, bin_centers[end] + step)
end
# ── _calculate_expected_aligned_error (ESMFold-specific helper) ─────────────

function _calculate_expected_aligned_error(alignment_confidence_breaks, aligned_distance_error_probs)
bin_centers = _calculate_bin_centers(alignment_confidence_breaks)
bin_centers = ProtInterop._calculate_bin_centers(alignment_confidence_breaks)
bview = reshape(bin_centers, ntuple(_ -> 1, ndims(aligned_distance_error_probs) - 1)..., length(bin_centers))
expected = sum(aligned_distance_error_probs .* bview; dims=ndims(aligned_distance_error_probs))
expected = dropdims(expected; dims=ndims(expected))
return expected, bin_centers[end]
end

function compute_predicted_aligned_error(
logits;
max_bin::Int = 31,
no_bins::Int = 64,
)
boundaries_cpu = collect(range(0f0, Float32(max_bin); length=no_bins - 1))
bin_centers_cpu = _calculate_bin_centers(boundaries_cpu)

aligned_confidence_probs = NNlib.softmax(logits; dims=1)
bin_centers = to_device(bin_centers_cpu, logits, Float32)
bview = reshape(bin_centers, length(bin_centers), ntuple(_ -> 1, ndims(aligned_confidence_probs) - 1)...)
expected = sum(aligned_confidence_probs .* bview; dims=1)
expected = dropdims(expected; dims=1)

return Dict(
:aligned_confidence_probs => aligned_confidence_probs,
:predicted_aligned_error => expected,
:max_predicted_aligned_error => bin_centers_cpu[end],
)
end

function compute_tm(
logits;
residue_weights = nothing,
asym_id = nothing,
interface::Bool = false,
max_bin::Int = 31,
no_bins::Int = 64,
eps::Real = 1e-8,
)
if residue_weights === nothing
residue_weights = ones_like(logits, size(logits, 2))
end

boundaries = range(0f0, Float32(max_bin); length=no_bins - 1)
boundaries = to_device(collect(boundaries), logits, Float32)
bin_centers = _calculate_bin_centers(boundaries)

clipped_n = max(sum(residue_weights), 19)
d0 = 1.24f0 * (clipped_n - 15)^(1f0 / 3f0) - 1.8f0

probs = NNlib.softmax(logits; dims=1)
tm_per_bin = 1f0 ./ (1f0 .+ (bin_centers .^ 2) ./ (d0 ^ 2))
tm_view = reshape(tm_per_bin, length(tm_per_bin), ntuple(_ -> 1, ndims(probs) - 1)...)
predicted_tm_term = sum(probs .* tm_view; dims=1)
predicted_tm_term = dropdims(predicted_tm_term; dims=1)

n = size(predicted_tm_term, 1)
pair_mask = ones_like(predicted_tm_term, Int, n, n)
if interface && asym_id !== nothing
pair_mask .= (reshape(asym_id, :, 1) .!= reshape(asym_id, 1, :))
end

predicted_tm_term = predicted_tm_term .* pair_mask

pair_residue_weights = pair_mask .* (reshape(residue_weights, :, 1) .* reshape(residue_weights, 1, :))
denom = eps .+ sum(pair_residue_weights; dims=2)
normed_residue_mask = pair_residue_weights ./ denom
per_alignment = sum(predicted_tm_term .* normed_residue_mask; dims=2)
per_alignment = dropdims(per_alignment; dims=2)

weighted = per_alignment .* residue_weights
max_idx = argmax(weighted)
return per_alignment[max_idx]
end
# Confidence metrics (compute_plddt, compute_predicted_aligned_error, compute_tm)
# are now imported from ProtInterop via `using ProtInterop` in the module definition.
# ESMFold re-exports them.
2 changes: 1 addition & 1 deletion src/pipeline.jl
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ function run_heads(model::ESMFoldModel, structure::AbstractDict, inputs::NamedTu

# Predicted aligned error
structure_update = compute_predicted_aligned_error(ptm_logits; max_bin=31, no_bins=model.distogram_bins)
for (k, v) in structure_update
for (k, v) in pairs(structure_update)
output[k] = v
end

Expand Down
Loading
Loading