|
| 1 | +#!/usr/bin/env python3 |
| 2 | + |
| 3 | +import matplotlib.pyplot as plt |
| 4 | +import matplotlib |
| 5 | +import json |
| 6 | +from typing import TypedDict |
| 7 | +import sys |
| 8 | +import argparse |
| 9 | +import sys |
| 10 | + |
# Counts per x-axis category, stored as (matched, maybe_matched, total).
# Besides plain year strings, the rest of this script also handles the keys
# 'total' (overall summary), 'unknown' and a '<=YYYY'-style bucket.
YearStats = dict[str, tuple[int, int, int]]


class Stats(TypedDict):
    """
    Shape of one stats JSON file as read by this script.

    `lineage`, `center`, `passage` and `assay` are used to group and title the
    subplots; `strains` and `measurements` hold the counts that get plotted.
    """

    lineage: str
    center: str
    passage: str
    assay: str
    strains: YearStats
    measurements: YearStats
| 20 | + |
def read_json(fname: str) -> Stats:
    """
    Load one stats JSON file and return its parsed content.

    The file is decoded explicitly as UTF-8 instead of the platform's locale
    encoding, so non-ASCII text is read identically on every machine.

    No schema validation is performed: the return value is whatever
    `json.load` produces (JSON arrays arrive as lists, not tuples).

    Raises OSError if the file cannot be opened and json.JSONDecodeError if
    its content is not valid JSON.
    """
    with open(fname, encoding="utf-8") as fh:
        return json.load(fh)
| 24 | + |
| 25 | + |
| 26 | + |
def order_years(stats: list[Stats]) -> tuple[list[str], list[str]]:
    """
    Produces the x-values for visualisation so that we have consistency across plots

    The categories of every 'strains' and 'measurements' mapping in `stats`
    are pooled and laid out as:

      1. 'unknown' (if observed), followed by an unlabeled '<dummy>' slot that
         visually separates it from the rest of the axis,
      2. any '<=...' open-ended categories, in sorted order,
      3. every year from the smallest to the largest observed one, with gaps
         filled in so the axis is continuous.

    The 'total' key is a summary entry and never becomes an x-value.

    Returns a tuple `(xvals, xlabels)` of equal length. `xvals` holds the
    category keys; `xlabels` holds the tick labels, where only every 5th year
    (by position) and the most recent year are labeled and other years get ''.

    Raises ValueError if a remaining key is not an integer year.
    """
    observed: set[str] = set()
    for s in stats:
        observed.update(s['strains'])
        observed.update(s['measurements'])

    # 'total' is the across-category summary; discard() tolerates inputs
    # that do not carry it.
    observed.discard('total')

    xvals: list[str] = []
    xlabels: list[str] = []

    if 'unknown' in observed:
        observed.remove('unknown')
        # the empty-labeled dummy slot spaces 'unknown' a little more to the left
        xvals += ['unknown', '<dummy>']
        xlabels += ['unknown', '']

    # Pull out every open-ended category (not just one) in a deterministic
    # order so that none is left behind to break the int() conversion below.
    open_ended = sorted(v for v in observed if v.startswith('<='))
    observed.difference_update(open_ended)
    xvals.extend(open_ended)
    xlabels.extend(open_ended)

    if not observed:
        # no plain years present: nothing more to lay out
        return (xvals, xlabels)

    # Compare years numerically rather than as strings.
    numeric_years = [int(v) for v in observed]
    years = [str(year) for year in range(min(numeric_years), max(numeric_years) + 1)]

    xvals.extend(years)
    xlabels.extend(y if idx != 0 and idx % 5 == 0 else '' for idx, y in enumerate(years))
    xlabels[-1] = xvals[-1]  # ensure last (most recent) value is shown
    return (xvals, xlabels)
| 63 | + |
| 64 | + |
def _percent(part: int, total: int) -> float:
    """Return part/total as a percentage, or NaN when total is zero."""
    return part / total * 100 if total else float('nan')


def _plot_series(ax, x_indices, xvals, counts, color, *, scale, alpha):
    """Draw the 'maybe' (dashed) and matched (solid) percentage lines of one YearStats onto ax."""
    # (tuple index, extra line kwargs); potential matches are drawn first
    for idx, line_kwargs in ((1, {'linestyle': '--'}), (0, {})):
        # NaN marks categories without data so that the line breaks there
        ys = [_percent(counts[x][idx], counts[x][2]) if x in counts else float('nan') for x in xvals]
        # marker size follows the absolute number of matches
        sizes = [counts[x][idx] * scale if x in counts else 0 for x in xvals]
        ax.plot(x_indices, ys, linewidth=0.5, color=color, **line_kwargs)
        ax.scatter(x_indices, ys, s=sizes, color=color, alpha=alpha, zorder=5, clip_on=False)


def _annotate_totals(ax, heading, totals, color, y_positions):
    """Draw the overall-percentage line and the three summary text rows for one 'total' entry."""
    pct_matched = _percent(totals[0], totals[2])
    pct_maybe = _percent(totals[1], totals[2])
    y_heading, y_matched, y_maybe = y_positions

    # solid horizontal line at the overall matched percentage
    ax.axhline(pct_matched, linestyle='-', color=color, alpha=0.7)
    ax.text(0.02, y_heading, heading, transform=ax.transAxes, color=color, fontsize=12)
    ax.text(0.02, y_matched, f'{totals[0]:,} / {totals[2]:,} ({pct_matched:.1f}%)',
            transform=ax.transAxes, color=color, fontsize=12)
    ax.text(0.02, y_maybe, f'maybes = {totals[1]:,} ({pct_maybe:.1f}%)',
            transform=ax.transAxes, color=color, fontsize=12)


def plot_matches(stats: list[Stats], output_file: str):
    """
    Create small-multiples visualization of fauna matching percentages over time.

    Layout: one row per center (sorted by name) and, within a row, one subplot
    per stats entry of that center. Rows with fewer entries than the widest
    row leave their trailing axes empty.

    Each subplot shows, for strains and for measurements (one colour each):
      - a solid line + circles: percentage matched (matched / total * 100)
      - a dashed line + circles: percentage of "maybe" (potential) matches
      - a solid horizontal line at the overall matched percentage, taken from
        the 'total' entry, plus text with the overall counts
    Circle sizes scale with the absolute number of matches. Categories with no
    data, or with a total of zero, are left blank.

    X-axis: the shared categories from `order_years`
    Y-axis: percentage, fixed to 0-100

    The figure is written to `output_file` at 300 dpi and then closed.

    Raises ValueError unless all entries in `stats` share exactly one lineage.
    """
    lineages = {s['lineage'] for s in stats}
    if len(lineages) != 1:
        raise ValueError(f"Expected stats for exactly one lineage, got: {sorted(lineages)}")
    lineage = next(iter(lineages))

    # sorted so the row order is the same on every run
    centers = sorted({s['center'] for s in stats})
    stats_per_center = [[s for s in stats if s['center'] == center] for center in centers]
    n_cols = max(len(group) for group in stats_per_center)
    xvals, xlabels = order_years(stats)
    x_indices = list(range(len(xvals)))

    c_strain = '#c51b8a'
    c_measurement = '#2c7fb8'
    circle_scalar = 1 / 10
    circle_alpha = 0.5

    fig, axes = plt.subplots(len(centers), n_cols, figsize=(n_cols * 5, len(centers) * 4), squeeze=False)
    try:
        for row_idx, center in enumerate(centers):
            for col_idx, data in enumerate(stats_per_center[row_idx]):
                ax = axes[row_idx, col_idx]

                _plot_series(ax, x_indices, xvals, data['strains'], c_strain,
                             scale=circle_scalar, alpha=circle_alpha)
                _plot_series(ax, x_indices, xvals, data['measurements'], c_measurement,
                             scale=circle_scalar, alpha=circle_alpha)

                # Overall percentages: horizontal lines plus text in the bottom left
                _annotate_totals(ax, 'Unique strains matched:', data['strains']['total'],
                                 c_strain, (0.42, 0.34, 0.26))
                _annotate_totals(ax, 'Unique measurements matched:', data['measurements']['total'],
                                 c_measurement, (0.18, 0.10, 0.02))

                ax.set_xticks(x_indices)
                ax.set_xticklabels(xlabels, rotation=45, ha='right')
                ax.set_xlim(-0.5, len(x_indices) - 0.5)

                if col_idx == 0:
                    ax.set_ylabel(center, fontsize=16)

                ax.set_title(f"{lineage} | {center} | {data['passage']} | {data['assay']}")
                ax.set_ylim(0, 100)
                ax.spines['right'].set_visible(False)
                ax.spines['top'].set_visible(False)

        # Add figure title
        fig.suptitle(f"Lineage: {lineage}", fontsize=16, fontweight='bold')

        # Adjust layout to prevent overlap, leaving room for the figure title
        fig.tight_layout(rect=[0, 0, 1, 0.98])

        # Save at high resolution
        fig.savefig(output_file, dpi=300, bbox_inches='tight')
    finally:
        # release the figure even if drawing or saving failed
        plt.close(fig)

    print(f"\nFigure saved to {output_file}", file=sys.stderr)
| 177 | + |
if __name__ == '__main__':
    # Script entry point: read every --stats JSON, plot them into --output.
    import traceback

    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--stats", nargs="+", required=True, metavar='JSON', help="stats JSONs")
    parser.add_argument("--output", required=True, metavar='PNG', help="output viz")
    args = parser.parse_args()

    try:
        # sorted so the subplot order does not depend on argument order
        stats = [read_json(s) for s in sorted(args.stats)]
        plot_matches(stats, args.output)
    except Exception:
        # Deliberately best-effort: any failure is reported on stderr, with the
        # full traceback so it can be diagnosed from logs, but the script still
        # exits with code 0 so that automated pipelines don't die.
        print("Visualisation script failed with error:", file=sys.stderr)
        traceback.print_exc(file=sys.stderr)
        print("Exiting with code 0 so that automated pipelines don't die", file=sys.stderr)
        # A placeholder is written at the output path, presumably so that
        # later pipeline steps expecting this file still find one.
        with open(args.output, 'w') as fh:
            print("Script failed!", file=fh)
| 193 | + |
0 commit comments