Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ version = "1.0.0-DEV"

[deps]
BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
BatchedTransformations = "8ba27c4b-52b5-4b10-bc66-a4fda05aa11b"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Einops = "e3ce28c8-8bfb-4704-8add-e3e7f14b55c9"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
HuggingFaceApi = "3cc741c3-0c9d-4fbe-84fa-cdec264173de"
Expand All @@ -17,9 +19,11 @@ Onion = "fdebf6c2-71da-43a1-b539-c3bc3e09c5c6"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
BFloat16s = "0.5, 0.6"
ChainRulesCore = "1.26.0"
Einops = "0.1.14"
Flux = "0.16.9"
HuggingFaceApi = "0.1"
Expand All @@ -28,6 +32,7 @@ NNlib = "0.9.33"
NPZ = "0.4.3"
Onion = "0.2.18"
SpecialFunctions = "2"
Zygote = "0.7.10"
julia = "1.10"

[extras]
Expand Down
61 changes: 59 additions & 2 deletions scripts/test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,20 +65,77 @@ open(generated_path, "w") do io
end

expected = read(expected_path, String)
function _parse_float_field(line::AbstractString, r::UnitRange{Int})
last(r) > lastindex(line) && return nothing
s = strip(line[r])
isempty(s) && return nothing
return tryparse(Float64, s)
end

function _pdb_line_match(expected::AbstractString, got::AbstractString; tol::Float64=0.02)
if startswith(expected, "ATOM") || startswith(expected, "HETATM")
startswith(got, "ATOM") || startswith(got, "HETATM") || return false
# Compare non-float fields with fixed-column slices.
slices = (
1:6, # record
7:11, # serial
13:16, # atom name
17:17, # alt loc
18:20, # res name
22:22, # chain id
23:26, # res seq
27:27, # insertion code
)
for r in slices
last(r) <= lastindex(expected) || return expected == got
last(r) <= lastindex(got) || return false
if expected[r] != got[r]
return false
end
end
# Float fields: x,y,z, occupancy, tempFactor
fields = (31:38, 39:46, 47:54, 55:60, 61:66)
for r in fields
e = _parse_float_field(expected, r)
g = _parse_float_field(got, r)
if e === nothing || g === nothing
return expected == got
end
if abs(e - g) > tol
return false
end
end
# Element symbol (optional)
if lastindex(expected) >= 78 && lastindex(got) >= 78
if expected[77:78] != got[77:78]
return false
end
end
return true
end
return expected == got
end

if expected != pdb
println("PDB mismatch.")
exp_lines = split(expected, '\n')
got_lines = split(pdb, '\n')
max_lines = min(length(exp_lines), length(got_lines))
matched = true
for i in 1:max_lines
if exp_lines[i] != got_lines[i]
if !_pdb_line_match(exp_lines[i], got_lines[i])
matched = false
println("First diff at line ", i)
println("expected: ", exp_lines[i])
println("got: ", got_lines[i])
break
end
end
error("PDB does not match expected output.")
if !matched || length(exp_lines) != length(got_lines)
error("PDB does not match expected output.")
else
println("PDB matches within tolerance.")
end
end

metrics = ESMFold.confidence_metrics(output)
Expand Down
4 changes: 3 additions & 1 deletion src/ESMFold.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,13 @@ export sequence_to_af2_indices
export infer, infer_pdb, infer_pdbs, output_to_pdb
export confidence_metrics
export set_training!, is_training
export make_atom14_masks!
export compute_tm, compute_predicted_aligned_error, categorical_lddt

include("alphabet.jl")
include("residue_constants.jl")
include("openfold_utils.jl")
include("layernorm_last.jl")
include("layernorm.jl")
include("rigid.jl")
include("openfold_feats.jl")
include("esmfold_misc.jl")
Expand Down
35 changes: 10 additions & 25 deletions src/device_utils.jl
Original file line number Diff line number Diff line change
@@ -1,35 +1,20 @@
# Device-agnostic helpers

function zeros_like(x::AbstractArray, dims::Int...)
y = similar(x, eltype(x), dims...)
fill!(y, zero(eltype(x)))
return y
end
using ChainRulesCore

function zeros_like(x::AbstractArray, ::Type{T}, dims::Int...) where {T}
y = similar(x, T, dims...)
fill!(y, zero(T))
return y
end
like(v, x::AbstractArray, args...) = @ignore_derivatives fill!(similar(x, args...), v)

function ones_like(x::AbstractArray, dims::Int...)
y = similar(x, eltype(x), dims...)
fill!(y, one(eltype(x)))
return y
end

function ones_like(x::AbstractArray, ::Type{T}, dims::Int...) where {T}
y = similar(x, T, dims...)
fill!(y, one(T))
return y
end
zeros_like(x::AbstractArray, args...) = like(false, x, args...)
ones_like(x::AbstractArray, args...) = like(true, x, args...)

function to_device(x::AbstractArray, like::AbstractArray, ::Type{T}=eltype(x)) where {T}
y = similar(like, T, size(x))
y .= T.(x)
return y
return @ignore_derivatives begin
y = similar(like, T, size(x))
y .= T.(x)
y
end
end

function to_device(x::Number, like::AbstractArray, ::Type{T}=typeof(x)) where {T}
return T(x)
return @ignore_derivatives T(x)
end
135 changes: 89 additions & 46 deletions src/esmfold_full.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ end
@layer ESMFoldLDDTHead

function ESMFoldLDDTHead(c_in::Int, c_hidden::Int, c_out::Int)
norm = LayerNormLast(c_in)
linear_1 = LinearLast(c_in, c_hidden)
linear_2 = LinearLast(c_hidden, c_hidden)
linear_3 = LinearLast(c_hidden, c_out)
norm = LayerNormFirst(c_in)
linear_1 = LinearFirst(c_in, c_hidden)
linear_2 = LinearFirst(c_hidden, c_hidden)
linear_3 = LinearFirst(c_hidden, c_out)
return ESMFoldLDDTHead(norm, linear_1, linear_2, linear_3)
end

Expand Down Expand Up @@ -65,9 +65,9 @@ function ESMFoldModel(
trunk = FoldingTrunk(cfg=cfg.trunk)

distogram_bins = 64
distogram_head = LinearLast(c_z, distogram_bins)
ptm_head = LinearLast(c_z, distogram_bins)
lm_head = LinearLast(c_s, embed.n_tokens_embed)
distogram_head = LinearFirst(c_z, distogram_bins)
ptm_head = LinearFirst(c_z, distogram_bins)
lm_head = LinearFirst(c_s, embed.n_tokens_embed)
lddt_bins = 50
lddt_head = ESMFoldLDDTHead(cfg.trunk.structure_module.c_s, cfg.lddt_head_hid_dim, 37 * lddt_bins)

Expand All @@ -85,10 +85,10 @@ function ESMFoldModel(
end

function _default_residx(aa::AbstractArray)
L = size(aa, 2)
L = size(aa, 1)
residx = collect(0:(L - 1))
residx = reshape(residx, 1, L)
residx = repeat(residx, size(aa, 1), 1)
residx = reshape(residx, L, 1)
residx = repeat(residx, 1, size(aa, 2))
return to_device(residx, aa, eltype(residx))
end

Expand All @@ -111,33 +111,35 @@ function (m::ESMFoldModel)(
masking_pattern = to_device(masking_pattern, aa, eltype(aa))
end

aa_fl = permutedims(aa, (2, 1))
mask_fl = permutedims(mask, (2, 1))
masking_pattern_fl = masking_pattern === nothing ? nothing : permutedims(masking_pattern, (2, 1))

embed_out = m.embed(
aa;
mask = mask,
masking_pattern = masking_pattern,
aa_fl;
mask = mask_fl,
masking_pattern = masking_pattern_fl,
return_pair = m.cfg.use_esm_attn_map,
)

if m.cfg.use_esm_attn_map
s_s_0_cf = embed_out.sequence
s_z_0_cf = embed_out.pair
s_s_0 = embed_out.sequence
s_z_0 = embed_out.pair
else
s_s_0_cf = embed_out
s_z_0_cf = nothing
s_s_0 = embed_out
s_z_0 = nothing
end

s_s_0 = permutedims(s_s_0_cf, (3, 2, 1))
s_z_0 = if s_z_0_cf === nothing
# (B, L, L, c_z)
s_z_0 = if s_z_0 === nothing
zeros_like(
s_s_0,
size(s_s_0, 1),
m.cfg.trunk.pairwise_state_dim,
size(s_s_0, 2),
size(s_s_0, 2),
m.cfg.trunk.pairwise_state_dim,
size(s_s_0, 3),
)
else
permutedims(s_z_0_cf, (4, 2, 3, 1))
s_z_0
end

structure = m.trunk(
Expand All @@ -159,12 +161,15 @@ function (m::ESMFoldModel)(
make_atom14_masks!(structure)

for k in (:atom14_atom_exists, :atom37_atom_exists)
structure[k] .*= reshape(mask, size(mask, 1), size(mask, 2), 1)
structure[k] .*= reshape(mask, 1, size(mask, 1), size(mask, 2))
end
structure[:residue_index] = residx

lddt_logits = m.lddt_head(structure[:states])
lddt_head = _reshape_last_corder(lddt_logits, 37, m.lddt_bins)
states = structure[:states]
states_cfirst = permutedims(states, (2, 1, 3, 4))
lddt_logits = m.lddt_head(states_cfirst)
lddt_tmp = _reshape_first_corder(lddt_logits, 37, m.lddt_bins)
lddt_head = permutedims(lddt_tmp, (3, 4, 5, 2, 1))
structure[:lddt_head] = lddt_head

plddt = categorical_lddt(lddt_head[end, :, :, :, :], bins=m.lddt_bins)
Expand All @@ -173,13 +178,13 @@ function (m::ESMFoldModel)(
ptm_logits = m.ptm_head(structure[:s_z])
structure[:ptm_logits] = ptm_logits

seqlen = sum(mask .== 1; dims=2)
ptm_vals = Vector{eltype(ptm_logits)}(undef, size(ptm_logits, 1))
for b in 1:size(ptm_logits, 1)
sl = Int(seqlen[b])
ptm_vals[b] = compute_tm(ptm_logits[b, 1:sl, 1:sl, :]; max_bin=31, no_bins=m.distogram_bins)
seqlen = sum(mask .== 1; dims=1)
ptm_vals = Vector{eltype(ptm_logits)}(undef, size(ptm_logits, 4))
for b in 1:size(ptm_logits, 4)
sl = Int(seqlen[1, b])
ptm_vals[b] = compute_tm(ptm_logits[:, 1:sl, 1:sl, b]; max_bin=31, no_bins=m.distogram_bins)
end
structure[:ptm] = to_device(reshape(collect(ptm_vals), size(ptm_logits, 1)), ptm_logits, eltype(ptm_logits))
structure[:ptm] = to_device(reshape(collect(ptm_vals), size(ptm_logits, 4)), ptm_logits, eltype(ptm_logits))

structure_update = compute_predicted_aligned_error(ptm_logits; max_bin=31, no_bins=m.distogram_bins)
for (k, v) in structure_update
Expand All @@ -196,7 +201,7 @@ function infer(
masking_pattern = nothing,
num_recycles = nothing,
residue_index_offset::Int = 512,
chain_linker::AbstractString = "G"^25,
chain_linker::Union{AbstractString,Int} = "G"^25,
)
seqs = isa(sequences, AbstractString) ? [sequences] : sequences

Expand All @@ -221,21 +226,27 @@ function infer(
masking_pattern = to_device(masking_pattern, like, eltype(aatype))
end

aatype_jl = permutedims(aatype, (2, 1))
mask_jl = permutedims(mask, (2, 1))
residx_jl = permutedims(residx, (2, 1))
masking_pattern_jl = masking_pattern === nothing ? nothing : permutedims(masking_pattern, (2, 1))

output = m(
aatype;
mask = mask,
residx = residx,
masking_pattern = masking_pattern,
aatype_jl;
mask = mask_jl,
residx = residx_jl,
masking_pattern = masking_pattern_jl,
num_recycles = num_recycles,
)

output[:atom37_atom_exists] .*= reshape(linker_mask, size(linker_mask, 1), size(linker_mask, 2), 1)
output[:atom37_atom_exists] .*= reshape(permutedims(linker_mask, (2, 1)), 1, size(linker_mask, 2), size(linker_mask, 1))

weighted_plddt = output[:plddt] .* output[:atom37_atom_exists]
numerator = sum(weighted_plddt; dims=(2, 3))
denom = sum(output[:atom37_atom_exists]; dims=(2, 3))
atom37 = permutedims(output[:atom37_atom_exists], (2, 3, 1)) # (L, B, 37)
weighted_plddt = output[:plddt] .* atom37
numerator = sum(weighted_plddt; dims=(1, 3))
denom = sum(atom37; dims=(1, 3))
output[:mean_plddt] = numerator ./ denom
output[:chain_index] = chain_index
output[:chain_index] = permutedims(chain_index, (2, 1))

return output
end
Expand All @@ -244,13 +255,45 @@ function output_to_pdb(m::ESMFoldModel, output::AbstractDict)
return output_to_pdb(output)
end

function infer_pdbs(m::ESMFoldModel, seqs::AbstractVector{<:AbstractString}; kwargs...)
output = infer(m, seqs; kwargs...)
function infer_pdbs(
m::ESMFoldModel,
seqs::AbstractVector{<:AbstractString};
residx = nothing,
masking_pattern = nothing,
num_recycles = nothing,
residue_index_offset::Int = 512,
chain_linker::Union{AbstractString,Int} = "G"^25,
)
output = infer(
m,
seqs;
residx = residx,
masking_pattern = masking_pattern,
num_recycles = num_recycles,
residue_index_offset = residue_index_offset,
chain_linker = chain_linker,
)
return output_to_pdb(output)
end

function infer_pdb(m::ESMFoldModel, seq::AbstractString; kwargs...)
return infer_pdbs(m, [seq]; kwargs...)[1]
function infer_pdb(
m::ESMFoldModel,
seq::AbstractString;
residx = nothing,
masking_pattern = nothing,
num_recycles = nothing,
residue_index_offset::Int = 512,
chain_linker::Union{AbstractString,Int} = "G"^25,
)
return infer_pdbs(
m,
[seq];
residx = residx,
masking_pattern = masking_pattern,
num_recycles = num_recycles,
residue_index_offset = residue_index_offset,
chain_linker = chain_linker,
)[1]
end

function confidence_metrics(output::AbstractDict)
Expand Down
Loading
Loading