Skip to content

Commit 60e5317

Browse files
committed
Added validation scripts
1 parent 26347ca commit 60e5317

File tree

2 files changed

+705
-0
lines changed

2 files changed

+705
-0
lines changed
Lines changed: 280 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,280 @@
1+
#!/usr/bin/env python3
2+
"""
3+
Validate bioassembly metadata CSV file.
4+
5+
This script validates that:
6+
1. The sequence field is properly reconstructed using stoichiometry and all_sequences
7+
2. Each chain's copies in stoichiometry matches the number of chains in FASTA description
8+
3. The sequence is a correct concatenation: chain_id_sequence1*copies1 + chain_id_sequence2*copies2 + ...
9+
"""
10+
11+
import csv
12+
from functools import partial
13+
from pathlib import Path
14+
import re
15+
import sys
16+
from typing import Dict, List, Tuple
17+
from dataclasses import dataclass
18+
import traceback
19+
20+
try:
21+
from biotite import sequence as bioseq
22+
23+
has_biotite = True
24+
except ImportError:
25+
has_biotite = False
26+
# Import parse_fasta from the tools/fasta directory
27+
sys.path.insert(0, str(Path(__file__).parent / "tools/fasta"))
28+
from chain_parser import parse_fasta
29+
30+
# Raise the csv field size limit as far as the platform allows so very long
# fields in a single CSV cell can be read.  ``sys.maxsize`` can be larger than
# the C ``long`` the csv module accepts on some platforms, in which case
# ``field_size_limit`` raises OverflowError; halve the value until it fits.
_field_size_limit = sys.maxsize
while True:
    try:
        csv.field_size_limit(_field_size_limit)
        break
    except OverflowError:
        _field_size_limit //= 2
31+
32+
33+
@dataclass
class ValidationResult:
    """Result of validation for a single row.

    Attributes:
        pdb_id: Identifier of the validated row.
        is_valid: True when no errors were recorded for the row.
        errors: Human-readable error messages.
        warnings: Human-readable warning messages. May be omitted (or passed
            as None) by the caller; it is then normalised to an empty list.
    """

    pdb_id: str
    is_valid: bool
    errors: List[str]
    warnings: List[str] = None

    def __post_init__(self):
        # Consumers iterate over ``warnings`` unconditionally, so a result
        # built without warnings must expose an empty list rather than None.
        # A fresh list per instance avoids sharing one mutable default.
        if self.warnings is None:
            self.warnings = []
41+
42+
43+
def parse_stoichiometry(stoichiometry: str) -> List[Tuple[str, int]]:
    """
    Turn a stoichiometry string into (chain_id, copies) pairs.

    Entries are separated by ";" and each entry has the form
    "chain_id:copies"; surrounding whitespace on either side of the colon
    is ignored.

    Args:
        stoichiometry: String like "A:2" or "A:1;B:1" or "B:60;C:60"

    Returns:
        List of tuples in input order, e.g. [("A", 2)] or
        [("A", 1), ("B", 1)]. An empty input yields an empty list.

    Raises:
        ValueError: If an entry does not have exactly one ":" or its copy
            count is not an integer.
    """
    if not stoichiometry:
        return []

    # Lazily split every entry; unpacking below enforces "exactly two fields".
    fields = (entry.split(":") for entry in stoichiometry.split(";"))
    return [(name.strip(), int(count.strip())) for name, count in fields]
62+
63+
64+
def _matches_ignoring_x(sequence: str, reconstructed: str) -> bool:
    """Return True if both strings have equal length and agree wherever ``reconstructed`` is not 'X'."""
    return len(sequence) == len(reconstructed) and all(
        expected == rebuilt
        for expected, rebuilt in zip(sequence, reconstructed)
        if rebuilt != "X"
    )


def _describe_alignment(sequence: str, reconstructed: str) -> str:
    """Return a printable alignment of the two sequences, or a message saying why none is available."""
    if not has_biotite:
        return "Biotite library not installed; alignment unavailable."

    try:
        # Import the alignment subpackage explicitly instead of relying on
        # ``bioseq.align`` already being loaded as an attribute of the parent
        # package (attribute access fails if nothing imported it earlier).
        import biotite.sequence.align as bioalign

        # The alignment is computed on a DNA-style alphabet: U is mapped to T
        # and the 'X' placeholder to 'N', then T is mapped back for display.
        sequence_t = sequence.replace("U", "T")
        reconstructed_t = reconstructed.replace("U", "T").replace("X", "N")
        seq1 = bioseq.NucleotideSequence(sequence_t)
        seq2 = bioseq.NucleotideSequence(reconstructed_t)
        alignments = bioalign.align_optimal(
            seq1,
            seq2,
            matrix=bioalign.SubstitutionMatrix.std_nucleotide_matrix(),
            gap_penalty=-5,
            terminal_penalty=False,
        )
        return str(alignments[0]).replace("T", "U")
    except Exception as e:
        # The alignment is diagnostic only; never let it mask the real error.
        return f"Error computing alignment: {e}"


def validate_row(
    pdb_id: str, stoichiometry: str, sequence: str, all_sequences: str
) -> ValidationResult:
    """
    Validate a single row of the CSV.

    The expected sequence is rebuilt by repeating each chain's FASTA sequence
    ``copies`` times, in stoichiometry order, and compared with ``sequence``.
    Positions where the rebuilt sequence holds 'X' are allowed to differ.
    A copy count that disagrees with the number of chains listed for that
    FASTA entry is reported as a warning, not an error.

    Args:
        pdb_id: PDB ID
        stoichiometry: Stoichiometry string
        sequence: Expected concatenated sequence
        all_sequences: FASTA formatted sequences

    Returns:
        ValidationResult object. ``warnings`` is always a list, including on
        the early-exit paths where parsing fails.
    """
    errors = []
    warnings = []

    # Parse stoichiometry
    try:
        stoich_list = parse_stoichiometry(stoichiometry)
    except Exception as e:
        errors.append(f"Failed to parse stoichiometry: {e}\n{traceback.format_exc()}")
        return ValidationResult(pdb_id, False, errors, warnings)

    # Parse FASTA; the result is used as a mapping
    # chain_id -> (chain_sequence, chain_list).
    try:
        fasta_dict = parse_fasta(all_sequences)
    except Exception as e:
        errors.append(f"Failed to parse FASTA: {e}\n{traceback.format_exc()}")
        return ValidationResult(pdb_id, False, errors, warnings)

    # Validate each chain in stoichiometry
    reconstructed_sequence = []

    for chain_id, copies in stoich_list:
        # Check if chain exists in FASTA
        if chain_id not in fasta_dict:
            errors.append(f"Chain '{chain_id}' from stoichiometry not found in FASTA")
            continue

        chain_sequence, chain_list = fasta_dict[chain_id]

        # Validate that copies matches the number of chains in FASTA description
        expected_copies = len(chain_list)
        if copies != expected_copies:
            warnings.append(
                f"Chain '{chain_id}': stoichiometry says {copies} copies, "
                f"but FASTA description lists {expected_copies} chains: {', '.join(chain_list)}"
            )

        # Add to reconstructed sequence
        reconstructed_sequence.append(chain_sequence * copies)

    # Reconstruct the full sequence
    reconstructed = "".join(reconstructed_sequence)

    # Compare with provided sequence, tolerating 'X' in the reconstruction
    if reconstructed != sequence and not _matches_ignoring_x(sequence, reconstructed):
        best_alignment = _describe_alignment(sequence, reconstructed)
        errors.append(
            f"Sequence mismatch:\n"
            f" Expected length: {len(sequence)}\n"
            f" Reconstructed length: {len(reconstructed)}\n"
            f" Expected: {sequence}\n"
            f" Reconstructed: {reconstructed}\n"
            f" Is partial match: {sequence in reconstructed or reconstructed in sequence}\n"
            f" Alignment:\n{best_alignment}\n"
        )

    is_valid = len(errors) == 0
    return ValidationResult(pdb_id, is_valid, errors, warnings)
161+
162+
163+
def validate_csv_file(
    filepath: str, verbose: bool = False
) -> Tuple[int, int, List[ValidationResult]]:
    """
    Validate entire CSV file.

    Each row is checked with ``validate_row``. Invalid rows are always
    printed with their errors and warnings; valid rows are printed only
    when ``verbose`` is set.

    Args:
        filepath: Path to CSV file. Rows must provide "stoichiometry",
            "sequence" and "all_sequences", plus "target_id" or "pdb_id".
        verbose: If True, print details for all rows; if False, only print errors

    Returns:
        Tuple of (total_rows, valid_rows, list_of_failed_validations)

    Raises:
        KeyError: If a required column is missing from a row.
    """
    total_rows = 0
    valid_rows = 0
    failed_validations = []

    with open(filepath, "r", encoding="utf-8") as f:
        reader = csv.DictReader(f)

        for row in reader:
            total_rows += 1
            # Prefer the "target_id" column; fall back to "pdb_id".
            target_id = row["target_id"] if "target_id" in row else row["pdb_id"]

            stoichiometry = row["stoichiometry"]
            sequence = row["sequence"]
            all_sequences = row["all_sequences"]

            result = validate_row(target_id, stoichiometry, sequence, all_sequences)

            # ``warnings`` may be None when a result was built without it
            # (e.g. on a parse failure); treat that as "no warnings" instead
            # of crashing the whole run with a TypeError.
            row_warnings = result.warnings or []

            if result.is_valid:
                valid_rows += 1
                if verbose:
                    print(f"✓ {target_id}: VALID")
                    for warning in row_warnings:
                        print(f" - WARNING: {warning}")
                    print()
            else:
                failed_validations.append(result)
                print(f"✗ {target_id}: INVALID")
                for error in result.errors:
                    print(f" - {error}")
                for warning in row_warnings:
                    print(f" - WARNING: {warning}")
                print()

    return total_rows, valid_rows, failed_validations
213+
214+
215+
def main():
    """Command-line entry point: validate a CSV file and print a summary.

    Exits with status 0 when every row is valid and 1 otherwise.
    """
    import argparse
    import io
    from contextlib import redirect_stdout

    parser = argparse.ArgumentParser(
        description="Validate bioassembly metadata CSV file",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
%(prog)s bioassembly_metadata.csv
%(prog)s bioassembly_metadata.csv --verbose
%(prog)s bioassembly_metadata.csv --summary-only
""",
    )
    parser.add_argument("csv_file", help="Path to CSV file to validate")
    parser.add_argument(
        "-v",
        "--verbose",
        action="store_true",
        help="Print details for all rows (not just errors)",
    )
    parser.add_argument(
        "-s",
        "--summary-only",
        action="store_true",
        help="Only print summary statistics",
    )
    args = parser.parse_args()

    print(f"Validating {args.csv_file}...")
    print()

    if args.summary_only:
        # Per-row output is discarded by sending stdout to a throwaway buffer.
        with redirect_stdout(io.StringIO()):
            outcome = validate_csv_file(args.csv_file, args.verbose)
    else:
        outcome = validate_csv_file(args.csv_file, args.verbose)
    total, valid, failed = outcome

    # Summary block
    rule = "=" * 70
    print(rule)
    print("VALIDATION SUMMARY")
    print(rule)
    print(f"Total rows validated: {total}")
    print(f"Valid rows: {valid}")
    print(f"Invalid rows: {len(failed)}")

    if total > 0:
        success_rate = (valid / total) * 100
        print(f"Success rate: {success_rate:.2f}%")

    if failed:
        failed_ids = ", ".join(result.pdb_id for result in failed)
        print()
        print(f"Failed PDB IDs: {failed_ids}")

    # Non-zero exit status signals that at least one row failed validation.
    sys.exit(1 if failed else 0)
277+
278+
279+
# Run the validator only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)