Skip to content

Commit b3f7d45

Browse files
committed
[titer download] viz for strain name matches
Visualises the stats from the matching approach in the previous commit
1 parent 0daccf1 commit b3f7d45

File tree

3 files changed

+257
-0
lines changed

3 files changed

+257
-0
lines changed

profiles/upload/upload.smk

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,22 @@ rule upload_all_titers:
2727
"data/upload/s3/titers/{build_name}/{titer_collection}.done".format(build_name=build_name, titer_collection=titer_collection["name"])
2828
for build_name, build_params in config["builds"].items()
2929
for titer_collection in build_params["titer_collections"]
30+
],
31+
titer_matching_viz=lambda wildcards: [
32+
f"data/upload/s3/titers/{lineage}_viz.done" for lineage in set([build['lineage'] for build in config["builds"].values()])
33+
]
34+
35+
36+
# Development-only rule which is the same as `upload_all_titers` without the actual uploading!
37+
rule dev_only_all_titers:
38+
input:
39+
titers=lambda wildcards: [
40+
"data/{build_name}/{titer_collection}_titers.tsv".format(build_name=build_name.split('_')[0], titer_collection=titer_collection["name"])
41+
for build_name, build_params in config["builds"].items()
42+
for titer_collection in build_params["titer_collections"]
43+
],
44+
titer_matching_viz=lambda wildcards: [
45+
f"data/{lineage}/titer-matches.png" for lineage in set([build['lineage'] for build in config["builds"].values()])
3046
]
3147

3248
rule upload_raw_sequences:
@@ -102,3 +118,19 @@ rule upload_titers:
102118
{input.titers:q} \
103119
{params.s3_dst:q}/{params.lineage}/{wildcards.titer_collection}_titers.tsv.gz 2>&1 | tee {output.flag}
104120
"""
121+
122+
rule upload_titer_match_viz:
123+
input:
124+
png="data/{lineage}/titer-matches.png",
125+
output:
126+
flag="data/upload/s3/titers/{lineage}_viz.done",
127+
params:
128+
s3_dst=config["s3_dst"],
129+
lineage="{lineage}"
130+
shell:
131+
"""
132+
./ingest/vendored/upload-to-s3 \
133+
--quiet \
134+
{input.png:q} \
135+
{params.s3_dst:q}/{params.lineage}/titer-matches.png 2>&1 | tee {output.flag}
136+
"""

scripts/titer-matching-viz.py

Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
#!/usr/bin/env python3
2+
3+
import matplotlib.pyplot as plt
4+
import matplotlib
5+
import json
6+
from typing import TypedDict
7+
import sys
8+
import argparse
9+
import sys
10+
11+
YearStats = dict[str, tuple[int, int, int]] # {year: (matched, maybe_matched, total)}
12+
13+
class Stats(TypedDict):
14+
lineage: str
15+
center: str
16+
passage: str
17+
assay: str
18+
strains: YearStats
19+
measurements: YearStats
20+
21+
def read_json(fname:str)->Stats:
22+
with open(fname) as fh:
23+
return json.load(fh)
24+
25+
26+
27+
def order_years(stats: list[Stats]) -> tuple[list[str], list[str]]:
28+
"""
29+
Produces the x-values for visualisation so that we have consistency across plots
30+
"""
31+
observed: set[str] = set()
32+
xvals: list[str] = []
33+
xlabels: list[str] = []
34+
35+
for s in stats:
36+
observed.update(s['strains'].keys())
37+
observed.update(s['measurements'].keys())
38+
39+
observed.remove('total')
40+
41+
if 'unknown' in observed:
42+
observed.remove('unknown')
43+
xvals.append('unknown')
44+
xlabels.append('unknown')
45+
# add empty x-value/label to space the unknown a little more to the left
46+
xvals.append('<dummy>')
47+
xlabels.append('')
48+
try:
49+
min_cat = [v for v in observed if v.startswith('<=')][0]
50+
observed.remove(min_cat)
51+
xvals.append(min_cat)
52+
xlabels.append(min_cat)
53+
except IndexError:
54+
pass
55+
56+
ordered = sorted(observed)
57+
years = [str(year) for year in range(int(ordered[0]), int(ordered[-1])+1)]
58+
59+
xvals.extend(years)
60+
xlabels.extend([y if idx!=0 and idx%5==0 else '' for idx,y in enumerate(years) ])
61+
xlabels[-1] = xvals[-1] # ensure last (most recent) value is shown
62+
return (xvals, xlabels)
63+
64+
65+
def plot_matches(stats: list[Stats], output_file: str):
66+
"""
67+
Create small-multiples visualization of fauna matching percentages over time.
68+
69+
Creates one subplot per subtype, with each contributor as a separate line.
70+
Two rows: fauna matches (top) and curated matches (bottom)
71+
X-axis: years
72+
Y-axis: percentage of strains matched (a/b*100)
73+
"""
74+
75+
lineages = set([s['lineage'] for s in stats])
76+
assert len(lineages)==1
77+
lineage = list(lineages)[0]
78+
centers = list(set([s['center'] for s in stats]))
79+
stats_per_center = [[s for s in stats if s['center']==center] for center in centers]
80+
max_stats_per_center = max([len(l) for l in stats_per_center])
81+
xvals, xlabels = order_years(stats)
82+
x_indices = list(range(len(xvals)))
83+
84+
c_strain = '#c51b8a'
85+
c_measurement = '#2c7fb8'
86+
87+
fig, axes = plt.subplots(len(centers), max_stats_per_center, figsize=(max_stats_per_center*5, len(centers) * 4), squeeze=False)
88+
89+
circle_scalar = 1/10
90+
circla_alpha = 0.5
91+
92+
for row_idx, center in enumerate(centers):
93+
for col_idx, data in enumerate(stats_per_center[row_idx]):
94+
ax = axes[row_idx, col_idx]
95+
96+
percentages = {
97+
'strains': {x:data['strains'][x][0]/data['strains'][x][2]*100 for x in xvals if x in data['strains']},
98+
'strains_maybe': {x:data['strains'][x][1]/data['strains'][x][2]*100 for x in xvals if x in data['strains']},
99+
'measurements': {x:data['measurements'][x][0]/data['measurements'][x][2]*100 for x in xvals if x in data['measurements']},
100+
'measurements_maybe': {x:data['measurements'][x][1]/data['measurements'][x][2]*100 for x in xvals if x in data['measurements']},
101+
}
102+
103+
# Extract y-values and sizes in the order of years
104+
strains_y = [percentages['strains'].get(x) for x in xvals]
105+
strains_maybe_y = [percentages['strains_maybe'].get(x) for x in xvals] # maybe: potential matches
106+
strains_sizes = [data['strains'][x][0]*circle_scalar if x in data['strains'] else 0 for x in xvals]
107+
strains_maybe_sizes = [data['strains'][x][1]*circle_scalar if x in data['strains'] else 0 for x in xvals]
108+
measurements_y = [percentages['measurements'].get(x) for x in xvals]
109+
measurements_maybe_y = [percentages['measurements_maybe'].get(x) for x in xvals] # maybe: potential matches
110+
measurements_sizes = [data['measurements'][x][0]*circle_scalar if x in data['measurements'] else 0 for x in xvals]
111+
measurements_maybe_sizes = [data['measurements'][x][1]*circle_scalar if x in data['measurements'] else 0 for x in xvals]
112+
113+
# Maybe (potential) matches of strains
114+
ax.plot(x_indices, strains_maybe_y,
115+
linewidth=0.5, linestyle='--', color=c_strain)
116+
ax.scatter(x_indices, strains_maybe_y,
117+
s=strains_maybe_sizes, color=c_strain, alpha=circla_alpha, zorder=5, clip_on=False)
118+
119+
# True matches of strains
120+
ax.plot(x_indices, strains_y,
121+
linewidth=0.5, color=c_strain)
122+
ax.scatter(x_indices, strains_y,
123+
s=strains_sizes, color=c_strain, alpha=circla_alpha, zorder=5, clip_on=False)
124+
125+
# Maybe (potential) matches of measurements
126+
ax.plot(x_indices, measurements_maybe_y,
127+
linewidth=0.5, linestyle='--', color=c_measurement)
128+
ax.scatter(x_indices, measurements_maybe_y,
129+
s=measurements_maybe_sizes, color=c_measurement, alpha=circla_alpha, zorder=5, clip_on=False)
130+
131+
# True matches of measurements
132+
ax.plot(x_indices, measurements_y,
133+
linewidth=0.5, color=c_measurement)
134+
ax.scatter(x_indices, measurements_y,
135+
s=measurements_sizes, color=c_measurement, alpha=circla_alpha, zorder=5, clip_on=False)
136+
137+
# Add horizontal dashed lines for overall percentages
138+
for idx in [0,1]:
139+
total_strains = data['strains']['total']
140+
total_measurements = data['measurements']['total']
141+
pct_strains = total_strains[idx] / total_strains[2] * 100
142+
pct_measurements = total_measurements[idx] / total_measurements[2] * 100
143+
144+
if idx==0:
145+
ax.axhline(pct_strains, linestyle='-', color=c_strain, alpha=0.7)
146+
ax.axhline(pct_measurements, linestyle='-', color=c_measurement, alpha=0.7)
147+
# Add text labels for total percentages in bottom left
148+
ax.text(0.02, 0.42, f'Unique strains matched:', transform=ax.transAxes, color=c_strain, fontsize=12)
149+
ax.text(0.02, 0.34, f'{total_strains[0]:,} / {total_strains[2]:,} ({pct_strains:.1f}%)', transform=ax.transAxes, color=c_strain, fontsize=12)
150+
ax.text(0.02, 0.18, f'Unique measurements matched:', transform=ax.transAxes, color=c_measurement, fontsize=12)
151+
ax.text(0.02, 0.10, f'{total_measurements[0]:,} / {total_measurements[2]:,} ({pct_measurements:.1f}%)', transform=ax.transAxes, color=c_measurement, fontsize=12)
152+
if idx==1:
153+
ax.text(0.02, 0.26, f'maybes = {total_strains[1]:,} ({pct_strains:.1f}%)', transform=ax.transAxes, color=c_strain, fontsize=12)
154+
ax.text(0.02, 0.02, f'maybes = {total_measurements[1]:,} ({pct_measurements:.1f}%)', transform=ax.transAxes, color=c_measurement, fontsize=12)
155+
156+
ax.set_xticks(x_indices)
157+
ax.set_xticklabels(xlabels, rotation=45, ha='right')
158+
ax.set_xlim(-0.5, len(x_indices) - 0.5)
159+
160+
if col_idx == 0:
161+
ax.set_ylabel(center, fontsize=16)
162+
163+
ax.set_title(f"{lineage} | {center} | {data['passage']} | {data['assay']}")
164+
ax.set_ylim(0, 100)
165+
ax.spines['right'].set_visible(False)
166+
ax.spines['top'].set_visible(False)
167+
168+
# Add figure title
169+
fig.suptitle(f"Lineage: {lineage}", fontsize=16, fontweight='bold')
170+
171+
# Adjust layout to prevent overlap
172+
plt.tight_layout(rect=[0, 0, 1, 0.98])
173+
174+
# Save as high-resolution PNG
175+
plt.savefig(output_file, dpi=300, bbox_inches='tight')
176+
print(f"\nFigure saved to {output_file}", file=sys.stderr)
177+
178+
if __name__ == '__main__':
179+
parser = argparse.ArgumentParser(description=__doc__)
180+
parser.add_argument("--stats", nargs="+", required=True, metavar='JSON', help="stats JSONs")
181+
parser.add_argument("--output", required=True, metavar='PNG', help="output viz")
182+
args = parser.parse_args()
183+
184+
try:
185+
stats = [read_json(s) for s in sorted(args.stats)]
186+
plot_matches(stats, args.output)
187+
except Exception as e:
188+
print(f"Visualisation script failed with error:", file=sys.stderr)
189+
print(e, file=sys.stderr)
190+
print(f"Exiting with code 0 so that automated pipelines don't die", file=sys.stderr)
191+
with open(args.output, 'w') as fh:
192+
print("Script failed!", file=fh)
193+

workflow/snakemake_rules/download_from_fauna.smk

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,3 +195,35 @@ rule select_titers_by_host:
195195
"""
196196
tsv-filter -H {params.host_query} {input.titers} > {output.titers} 2> {log}
197197
"""
198+
199+
200+
def _stats(wildcards):
201+
"""
202+
Collect all the relevant (titer-matching) stats JSONs for the requested wildcards.lineage
203+
"""
204+
stats = set()
205+
matching_builds = [build for build in config["builds"].values() if build['lineage']==wildcards.lineage]
206+
for build in matching_builds:
207+
for collection in build['titer_collections']:
208+
center, _host, passage, assay = collection['name'].split('_')
209+
stats.add(f"data/{wildcards.lineage}/{center}_{passage}_{assay}_matching-stats.json")
210+
return sorted(list(stats))
211+
212+
rule visualise_titer_matches:
213+
"""
214+
Visualise the titer matche stats produced by the remap_titer_strain_names
215+
rule above. This script will _always_ exit 0 and produce the output file
216+
to avoid viz errors terminating the pipeline
217+
"""
218+
input:
219+
stats = _stats
220+
output:
221+
png = "data/{lineage}/titer-matches.png"
222+
conda: "../envs/nextstrain.yaml"
223+
log:
224+
"logs/visualise_titer_matches_{lineage}.txt"
225+
shell:
226+
"""
227+
./scripts/titer-matching-viz.py --stats {input.stats} --output {output.png} 2> {log}
228+
"""
229+

0 commit comments

Comments
 (0)