1- using NNlib
1+ # OpenFold inference utilities.
2+ # Atom14/atom37 data tables and confidence metrics come from ProtInterop.
3+ # This file keeps only ESMFold-specific wrappers.
24
3- const _restype_atom14_to_atom37 = let
4- rows = Vector {Vector{Int}} ()
5- for rt in restypes
6- atom_names = restype_name_to_atom14_names[restype_1to3[rt]]
7- push! (rows, [name == " " ? 1 : (atom_order[name] + 1 ) for name in atom_names])
8- end
9- push! (rows, fill (1 , 14 )) # UNK
10- reduce (vcat, (reshape (r, 1 , :) for r in rows))
11- end
5+ using ProtInterop: OF_RESTYPE_ATOM14_TO_ATOM37, OF_RESTYPE_ATOM37_TO_ATOM14,
6+ OF_RESTYPE_ATOM14_MASK, OF_RESTYPE_ATOM37_MASK
127
13- const _restype_atom37_to_atom14 = let
14- rows = Vector {Vector{Int}} ()
15- for rt in restypes
16- atom_names = restype_name_to_atom14_names[restype_1to3[rt]]
17- atom_name_to_idx14 = Dict (name => i for (i, name) in enumerate (atom_names) if name != " " )
18- push! (rows, [get (atom_name_to_idx14, name, 1 ) for name in atom_types])
19- end
20- push! (rows, fill (1 , length (atom_types))) # UNK
21- reduce (vcat, (reshape (r, 1 , :) for r in rows))
22- end
8+ # ── Aliases for backward compatibility ──────────────────────────────────────
239
24- const _restype_atom14_mask = let
25- rows = Vector {Vector{Float32}} ()
26- for rt in restypes
27- atom_names = restype_name_to_atom14_names[restype_1to3[rt]]
28- push! (rows, [name == " " ? 0f0 : 1f0 for name in atom_names])
29- end
30- push! (rows, fill (0f0 , 14 )) # UNK
31- reduce (vcat, (reshape (r, 1 , :) for r in rows))
32- end
10+ const _restype_atom14_to_atom37 = OF_RESTYPE_ATOM14_TO_ATOM37
11+ const _restype_atom37_to_atom14 = OF_RESTYPE_ATOM37_TO_ATOM14
12+ const _restype_atom14_mask = OF_RESTYPE_ATOM14_MASK
13+ const _restype_atom37_mask = OF_RESTYPE_ATOM37_MASK
3314
34- const _restype_atom37_mask = let
35- mask = zeros (Float32, length (restypes), length (atom_types))
36- for (restype, restype_letter) in enumerate (restypes)
37- restype_name = restype_1to3[restype_letter]
38- atom_names = residue_atoms[restype_name]
39- for atom_name in atom_names
40- atom_type = atom_order[atom_name] + 1
41- mask[restype, atom_type] = 1f0
42- end
43- end
44- mask = vcat (mask, zeros (Float32, 1 , length (atom_types))) # UNK
45- mask
46- end
15+ # ── make_atom14_masks! (ESMFold-specific: mutating API, modifies Dict) ──────
4716
4817function make_atom14_masks! (protein:: AbstractDict )
4918 protein_aatype = protein[:aatype ] .+ 1
@@ -67,82 +36,16 @@ function make_atom14_masks!(protein::AbstractDict)
6736 return protein
6837end
6938
70- function _calculate_bin_centers (boundaries:: AbstractArray )
71- step = boundaries[2 ] - boundaries[1 ]
72- bin_centers = boundaries .+ step / 2
73- return vcat (bin_centers, bin_centers[end ] + step)
74- end
39+ # ── _calculate_expected_aligned_error (ESMFold-specific helper) ─────────────
7540
7641function _calculate_expected_aligned_error (alignment_confidence_breaks, aligned_distance_error_probs)
77- bin_centers = _calculate_bin_centers (alignment_confidence_breaks)
42+ bin_centers = ProtInterop . _calculate_bin_centers (alignment_confidence_breaks)
7843 bview = reshape (bin_centers, ntuple (_ -> 1 , ndims (aligned_distance_error_probs) - 1 )... , length (bin_centers))
7944 expected = sum (aligned_distance_error_probs .* bview; dims= ndims (aligned_distance_error_probs))
8045 expected = dropdims (expected; dims= ndims (expected))
8146 return expected, bin_centers[end ]
8247end
8348
84- function compute_predicted_aligned_error (
85- logits;
86- max_bin:: Int = 31 ,
87- no_bins:: Int = 64 ,
88- )
89- boundaries_cpu = collect (range (0f0 , Float32 (max_bin); length= no_bins - 1 ))
90- bin_centers_cpu = _calculate_bin_centers (boundaries_cpu)
91-
92- aligned_confidence_probs = NNlib. softmax (logits; dims= 1 )
93- bin_centers = to_device (bin_centers_cpu, logits, Float32)
94- bview = reshape (bin_centers, length (bin_centers), ntuple (_ -> 1 , ndims (aligned_confidence_probs) - 1 )... )
95- expected = sum (aligned_confidence_probs .* bview; dims= 1 )
96- expected = dropdims (expected; dims= 1 )
97-
98- return Dict (
99- :aligned_confidence_probs => aligned_confidence_probs,
100- :predicted_aligned_error => expected,
101- :max_predicted_aligned_error => bin_centers_cpu[end ],
102- )
103- end
104-
105- function compute_tm (
106- logits;
107- residue_weights = nothing ,
108- asym_id = nothing ,
109- interface:: Bool = false ,
110- max_bin:: Int = 31 ,
111- no_bins:: Int = 64 ,
112- eps:: Real = 1e-8 ,
113- )
114- if residue_weights === nothing
115- residue_weights = ones_like (logits, size (logits, 2 ))
116- end
117-
118- boundaries = range (0f0 , Float32 (max_bin); length= no_bins - 1 )
119- boundaries = to_device (collect (boundaries), logits, Float32)
120- bin_centers = _calculate_bin_centers (boundaries)
121-
122- clipped_n = max (sum (residue_weights), 19 )
123- d0 = 1.24f0 * (clipped_n - 15 )^ (1f0 / 3f0 ) - 1.8f0
124-
125- probs = NNlib. softmax (logits; dims= 1 )
126- tm_per_bin = 1f0 ./ (1f0 .+ (bin_centers .^ 2 ) ./ (d0 ^ 2 ))
127- tm_view = reshape (tm_per_bin, length (tm_per_bin), ntuple (_ -> 1 , ndims (probs) - 1 )... )
128- predicted_tm_term = sum (probs .* tm_view; dims= 1 )
129- predicted_tm_term = dropdims (predicted_tm_term; dims= 1 )
130-
131- n = size (predicted_tm_term, 1 )
132- pair_mask = ones_like (predicted_tm_term, Int, n, n)
133- if interface && asym_id != = nothing
134- pair_mask .= (reshape (asym_id, :, 1 ) .!= reshape (asym_id, 1 , :))
135- end
136-
137- predicted_tm_term = predicted_tm_term .* pair_mask
138-
139- pair_residue_weights = pair_mask .* (reshape (residue_weights, :, 1 ) .* reshape (residue_weights, 1 , :))
140- denom = eps .+ sum (pair_residue_weights; dims= 2 )
141- normed_residue_mask = pair_residue_weights ./ denom
142- per_alignment = sum (predicted_tm_term .* normed_residue_mask; dims= 2 )
143- per_alignment = dropdims (per_alignment; dims= 2 )
144-
145- weighted = per_alignment .* residue_weights
146- max_idx = argmax (weighted)
147- return per_alignment[max_idx]
148- end
49+ # Confidence metrics (compute_plddt, compute_predicted_aligned_error, compute_tm)
50+ # are now imported from ProtInterop via `using ProtInterop` in the module definition.
51+ # ESMFold re-exports them.
0 commit comments