Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 53 additions & 10 deletions src/metfish/analysis/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from biopandas.pdb import PandasPdb
from typing import List, Dict, Tuple, Optional

from metfish.utils import get_lddt, save_aligned_pdb, get_Pr
from metfish.utils import get_rmsd, get_lddt, save_aligned_pdb, get_Pr


class ProteinStructureAnalyzer:
Expand All @@ -35,6 +35,9 @@ def calculate_metrics(self,

# Calculate SAXS KL divergence
saxs_kldiv = self._calculate_saxs_kldiv(p_of_r_a, p_of_r_b, r_a, r_b)

# Calculate SAXS L1 loss
saxs_l1 = self._calculate_saxs_l1(p_of_r_a, p_of_r_b, r_a, r_b)

# calculate radius of gyration
md_obj_a = md.load(fname_a)
Expand All @@ -55,6 +58,7 @@ def calculate_metrics(self,
'rmsd': rmsd,
'lddt': lddt,
'saxs_kldiv': saxs_kldiv,
'saxs_l1': saxs_l1,
'rg_a': rg_a,
'rg_b': rg_b,
'rg_diff': rg_diff,
Expand Down Expand Up @@ -86,6 +90,25 @@ def _calculate_saxs_kldiv(self,
saxs_b_padded = (saxs_b_padded + eps) / np.sum(saxs_b_padded + eps)

return np.sum(rel_entr(saxs_a_padded, saxs_b_padded))

def _calculate_saxs_l1(self,
p_of_r_a: np.ndarray,
p_of_r_b: np.ndarray,
r_a: np.ndarray,
r_b: np.ndarray) -> float:
"""Calculate L1 loss (MAE) between SAXS profiles."""
max_len = max(len(r_a), len(r_b))

# Pad both arrays to same length
saxs_a_padded = np.pad(p_of_r_a, (0, max_len - len(r_a)),
mode='constant', constant_values=0)
saxs_b_padded = np.pad(p_of_r_b, (0, max_len - len(r_b)),
mode='constant', constant_values=0)

# Calculate L1 loss (sum of absolute differences)
return np.sum(np.abs(saxs_a_padded - saxs_b_padded))



class ModelComparisonProcessor:
"""Process and create comparison dataframes for multiple models."""
Expand Down Expand Up @@ -118,15 +141,35 @@ def _get_apo_holo_pairs(self) -> List[Tuple[str, str]]:
return None

def _create_comparisons_list(self) -> List[Tuple[str, str]]:
comparisons=[('out', 'target'), ('out', 'out_alt'), ('target', 'target_alt')]
all_comparisons = [
tuple(c.replace('out', f'out_{t}') for c in comp)
for comp in comparisons
if any('out' == x or 'out_alt' == x for x in comp)
for t in self.tags
]
all_comparisons.extend([tuple(f'out_{x}' for x in comb) for comb in itertools.combinations(self.tags, 2)])
return all_comparisons
"""
Generate all comparison pairs for protein structure analysis.

Returns:
List of (structure_a, structure_b) tuples representing all comparisons:
- 3 model vs target comparisons (AF vs target, NMR vs target, NMA vs target)
- 3 model vs alternative structure comparisons (AF vs AF_alt, NMR vs NMR_alt, NMA vs NMA_alt)
- 1 target vs alternative target comparison (target vs target_alt)
- 3 cross-model comparisons (AF vs NMR, AF vs NMA, NMR vs NMA)
Total: 10 comparisons per protein
"""
comparisons = []

# 1. Each model vs target (3 comparisons)
for model_tag in self.tags.keys():
comparisons.append((f'out_{model_tag}', 'target'))

# 2. Each model vs its alternative structure (3 comparisons)
for model_tag in self.tags.keys():
comparisons.append((f'out_{model_tag}', f'out_alt_{model_tag}'))

# 3. Target vs alternative target (1 comparison)
comparisons.append(('target', 'target_alt'))

# 4. Cross-model comparisons (3 comparisons: AF vs NMR, AF vs NMA, NMR vs NMA)
for model_a, model_b in itertools.combinations(self.tags.keys(), 2):
comparisons.append((f'out_{model_a}', f'out_{model_b}'))

return comparisons

def get_comparison_df(self,
names: List[str] = None,
Expand Down
2 changes: 1 addition & 1 deletion src/metfish/analysis/visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,4 +263,4 @@ def _plot_protein_saxs_summary(self, name: str, ax: plt.Axes):
ax.set_ylabel('P(r)', fontsize=8)
ax.set_title(name, fontsize=9, fontweight='bold')
sns.despine(ax=ax)


Loading