Skip to content

Commit 1843dcb

Browse files
committed
Support fitting titer models per reference virus
Allow titer collections to request a titer model run per reference virus in each collection. The workflow aggregates the resulting model outputs into a single measurements panel JSON, which lets users display inferred titer measurements per virus both in the measurements panel and in the tree coloring by measurements for a reference selected from the panel. This logic recreates a nextflu feature that lets users click on a titer reference virus to color the tree by the measurements against that virus and then choose to color the tree by the titer model fit to that virus. Related to #214
1 parent 0be1114 commit 1843dcb

File tree

5 files changed

+280
-13
lines changed

5 files changed

+280
-13
lines changed

profiles/full-trees.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ builds:
116116
prefix: cell_hi_
117117
title: "Cell-passaged HI titers from ferret sera"
118118
genes: ["HA1"]
119+
run_reference_models: true
119120
- name: egg_hi
120121
data: "data/h1n1pdm/who_ferret_egg_hi_titers.tsv"
121122
prefix: egg_hi_

scripts/generate_collection_config_json.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -71,23 +71,26 @@ def calc_row(clade):
7171
}
7272

7373
# Read collection.
74-
collection_df = pd.read_csv(args.collection, sep="\t", usecols=args.groupings + ["reference_date", "clade_reference"])
74+
collection_df = pd.read_csv(args.collection, sep="\t", usecols=args.groupings)
7575

7676
# Map y-axis positions in the phylogeny to reference strains.
7777
collection_df["y_axis_position_in_phylogeny"] = collection_df["reference_strain"].map(y_axis_positions_per_tip_name)
7878

79-
# Find minimum y-axis position for reference strains within each clade. This
80-
# position represents the earliest instance of the clade in the tree.
81-
min_y_axis_position_by_reference_clade = collection_df.groupby("subclade_reference")["y_axis_position_in_phylogeny"].min().reset_index().rename(
82-
columns={"y_axis_position_in_phylogeny": "min_y_axis_position_in_phylogeny"}
83-
)
84-
85-
# Annotate min y-axis position per clade to collection.
86-
collection_df = collection_df.merge(
87-
min_y_axis_position_by_reference_clade,
88-
on="subclade_reference",
89-
how="left",
90-
)
79+
if "subclade_reference" in collection_df.columns:
80+
# Find minimum y-axis position for reference strains within each clade. This
81+
# position represents the earliest instance of the clade in the tree.
82+
min_y_axis_position_by_reference_clade = collection_df.groupby("subclade_reference")["y_axis_position_in_phylogeny"].min().reset_index().rename(
83+
columns={"y_axis_position_in_phylogeny": "min_y_axis_position_in_phylogeny"}
84+
)
85+
86+
# Annotate min y-axis position per clade to collection.
87+
collection_df = collection_df.merge(
88+
min_y_axis_position_by_reference_clade,
89+
on="subclade_reference",
90+
how="left",
91+
)
92+
else:
93+
collection_df["min_y_axis_position_in_phylogeny"] = collection_df["y_axis_position_in_phylogeny"]
9194

9295
# Sort collection by y-axis position.
9396
sorted_df = collection_df.sort_values(
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
#!/usr/bin/env python3
2+
import argparse
3+
from augur.utils import read_node_data
4+
import pandas as pd
5+
6+
7+
if __name__ == '__main__':
8+
parser = argparse.ArgumentParser()
9+
parser.add_argument("--titer-model", required=True, help="node data JSON from titer model with inferred titers annotated in 'nodes' key by field ending with 'cTiterSub'")
10+
parser.add_argument("--titers", required=True, help="TSV of titers used to fit the given model")
11+
parser.add_argument("--annotations", nargs="+", help="additional annotations to add to the output table in the format of 'key=value' pairs")
12+
parser.add_argument("--output", required=True, help="table of antigenic distances in log2 titers between reference and test strains")
13+
14+
args = parser.parse_args()
15+
16+
# Load raw titers to get the reference name.
17+
raw_titers = pd.read_csv(
18+
args.titers,
19+
sep="\t",
20+
nrows=2,
21+
)
22+
reference = raw_titers["serum_strain"].values[0]
23+
24+
# Load titer model data.
25+
titer_data = read_node_data(args.titer_model)["nodes"]
26+
27+
# Convert titer data to a data frame.
28+
titer_records = []
29+
for test_strain, test_strain_values in titer_data.items():
30+
for key, value in test_strain_values.items():
31+
if key.endswith("cTiterSub"):
32+
titer_records.append({
33+
"reference_strain": reference,
34+
"test_strain": test_strain,
35+
"log2_titer": value,
36+
})
37+
38+
titer_table = pd.DataFrame(titer_records)
39+
40+
# Add any additional annotations requested by the user in the format of
41+
# "key=value" pairs where each key becomes a new column with the given
42+
# value.
43+
if args.annotations:
44+
for annotation in args.annotations:
45+
key, value = annotation.split("=")
46+
titer_table[key] = value
47+
48+
# Save the annotated table.
49+
titer_table.to_csv(
50+
args.output,
51+
sep="\t",
52+
index=False,
53+
float_format="%.4f"
54+
)
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
"""Split titers into separate files per reference virus.
2+
"""
3+
import argparse
4+
import pandas as pd
5+
6+
7+
if __name__ == "__main__":
8+
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
9+
parser.add_argument("--titers", required=True, help="TSV of titers to split by reference")
10+
parser.add_argument("--output-references", required=True, help="text file listing the references with titer outputs")
11+
parser.add_argument("--output-titers-directory", required=True, help="directory where split titers TSV are placed per reference")
12+
13+
args = parser.parse_args()
14+
15+
titers = pd.read_csv(
16+
args.titers,
17+
sep="\t",
18+
)
19+
20+
# Find references with autologous and heterologous measurements.
21+
distinct_pairs = titers.loc[:, ["virus_strain", "serum_strain"]].drop_duplicates()
22+
print(f"Found {distinct_pairs.shape[0]} distinct pairs")
23+
24+
has_autologous_measurement = (distinct_pairs["virus_strain"] == distinct_pairs["serum_strain"])
25+
autologous_references = set(distinct_pairs.loc[has_autologous_measurement, "serum_strain"].drop_duplicates().values)
26+
print(f"Found {len(autologous_references)} autologous references")
27+
28+
has_heterologous_measurement = (distinct_pairs["virus_strain"] != distinct_pairs["serum_strain"])
29+
heterologous_references = set(distinct_pairs.loc[has_heterologous_measurement, "serum_strain"].drop_duplicates().values)
30+
print(f"Found {len(heterologous_references)} heterologous references")
31+
32+
selected_references = autologous_references & heterologous_references
33+
print(f"Found {len(selected_references)} references")
34+
35+
selected_titers = titers[titers["serum_strain"].isin(selected_references)].copy()
36+
selected_titers["reference_path"] = selected_titers["serum_strain"].apply(
37+
lambda strain: strain.replace("/", "_")
38+
)
39+
selected_reference_paths = selected_titers["reference_path"].drop_duplicates().values
40+
41+
for reference, reference_titers in selected_titers.groupby("reference_path"):
42+
reference_titers.to_csv(
43+
f"{args.output_titers_directory}/{reference}.tsv",
44+
sep="\t",
45+
index=False,
46+
)
47+
48+
with open(args.output_references, "w", encoding="utf-8") as oh:
49+
for reference in selected_reference_paths:
50+
print(reference, file=oh)

workflow/snakemake_rules/titer_models.smk

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,11 +232,170 @@ rule export_measurements:
232232
--output-json {output.measurements} 2>&1 | tee {log}
233233
"""
234234

235+
checkpoint get_titers_per_reference:
236+
input:
237+
titers="builds/{build_name}/titers/{titer_collection}.tsv",
238+
output:
239+
references="builds/{build_name}/titer_references/{titer_collection}.txt",
240+
reference_titers_directory=directory("builds/{build_name}/reference_titers/{titer_collection}/"),
241+
conda: "../envs/nextstrain.yaml"
242+
shell:
243+
r"""
244+
mkdir -p {output.reference_titers_directory};
245+
246+
python scripts/get_titers_per_reference.py \
247+
--titers {input.titers} \
248+
--output-references {output.references} \
249+
--output-titers-directory {output.reference_titers_directory}
250+
"""
251+
252+
rule reference_model_titers_sub:
253+
input:
254+
titers = build_dir +"/{build_name}/reference_titers/{titer_collection}/{reference}.tsv",
255+
tree = rules.refine.output.tree,
256+
translations_done = build_dir + "/{build_name}/{segment}/translations.done"
257+
params:
258+
genes = get_titer_collection_genes,
259+
translations = lambda wildcards: [f"{build_dir}/{wildcards.build_name}/{wildcards.segment}/translations/{gene}_withInternalNodes.fasta" for gene in get_titer_collection_genes(wildcards)],
260+
attribute_prefix_argument = get_titer_collection_attribute_prefix_argument,
261+
output:
262+
titers_model = build_dir + "/{build_name}/{segment}/reference-titers-sub-model/{titer_collection}/{reference}.json",
263+
conda: "../envs/nextstrain.yaml"
264+
benchmark:
265+
"benchmarks/titers_sub_{build_name}_{segment}_{titer_collection}_{reference}.txt",
266+
log:
267+
"logs/titers_sub_{build_name}_{segment}_{titer_collection}_{reference}.txt",
268+
resources:
269+
mem_mb=8000,
270+
shell:
271+
"""
272+
augur titers sub \
273+
--titers {input.titers} \
274+
--alignment {params.translations} \
275+
--gene-names {params.genes} \
276+
--tree {input.tree} \
277+
--allow-empty-model \
278+
{params.attribute_prefix_argument} \
279+
--output {output.titers_model} 2>&1 | tee {log}
280+
"""
281+
282+
rule reference_model_antigenic_distances_between_strains:
283+
input:
284+
titer_model="builds/{build_name}/{segment}/reference-titers-sub-model/{titer_collection}/{reference}.json",
285+
titers="builds/{build_name}/reference_titers/{titer_collection}/{reference}.tsv",
286+
output:
287+
distances="builds/{build_name}/{segment}/reference_model_antigenic_distances_between_strains/{titer_collection}/{reference}.tsv",
288+
benchmark:
289+
"benchmarks/reference_model_antigenic_distances_between_strains_{build_name}_{segment}_{titer_collection}_{reference}.txt"
290+
log:
291+
"logs/reference_model_antigenic_distances_between_strains_{build_name}_{segment}_{titer_collection}_{reference}.txt"
292+
conda: "../envs/nextstrain.yaml"
293+
shell:
294+
"""
295+
python3 scripts/get_antigenic_distances_for_reference_model.py \
296+
--titer-model {input.titer_model} \
297+
--titers {input.titers} \
298+
--output {output.distances} &> {log}
299+
"""
300+
301+
def aggregate_reference_model_distances_input(wildcards):
302+
with checkpoints.get_titers_per_reference.get(**wildcards).output["references"].open() as fh:
303+
distances = [
304+
f"builds/{wildcards.build_name}/{wildcards.segment}/reference_model_antigenic_distances_between_strains/{wildcards.titer_collection}/{reference.strip()}.tsv"
305+
for reference in fh
306+
]
307+
308+
return distances
309+
310+
rule aggregate_reference_model_distances:
311+
input:
312+
distances=aggregate_reference_model_distances_input,
313+
output:
314+
distances="builds/{build_name}/{segment}/reference_model_antigenic_distances_between_strains/{titer_collection}.tsv",
315+
conda: "../envs/nextstrain.yaml"
316+
shell:
317+
r"""
318+
tsv-append -H {input.distances} > {output.distances}
319+
"""
320+
321+
rule generate_reference_model_collection_config_json:
322+
input:
323+
distances="builds/{build_name}/{segment}/reference_model_antigenic_distances_between_strains/{titer_collection}.tsv",
324+
tree="builds/{build_name}/{segment}/tree.nwk",
325+
output:
326+
config_json="builds/{build_name}/{segment}/reference_model_measurements_collection_config/{titer_collection}.json",
327+
conda: "../envs/nextstrain.yaml"
328+
params:
329+
groupings=[
330+
"reference_strain",
331+
],
332+
fields=[
333+
"strain",
334+
"reference_strain",
335+
"value",
336+
],
337+
log:
338+
"logs/generate_reference_model_collection_config_json_{build_name}_{segment}_{titer_collection}.txt"
339+
shell:
340+
"""
341+
python3 scripts/generate_collection_config_json.py \
342+
--tree {input.tree} \
343+
--collection {input.distances} \
344+
--groupings {params.groupings:q} \
345+
--fields {params.fields:q} \
346+
--output {output.config_json} &> {log}
347+
"""
348+
349+
rule export_reference_model_measurements:
350+
input:
351+
distances="builds/{build_name}/{segment}/reference_model_antigenic_distances_between_strains/{titer_collection}.tsv",
352+
configuration="builds/{build_name}/{segment}/reference_model_measurements_collection_config/{titer_collection}.json",
353+
output:
354+
measurements="builds/{build_name}/{segment}/reference_model_measurements/{titer_collection}.json",
355+
conda: "../envs/nextstrain.yaml"
356+
benchmark:
357+
"benchmarks/export_reference_model_measurements_{build_name}_{segment}_{titer_collection}.txt"
358+
log:
359+
"logs/export_reference_model_measurements_{build_name}_{segment}_{titer_collection}.txt"
360+
params:
361+
strain_column="test_strain",
362+
value_column="log2_titer",
363+
title=lambda wildcards: get_titer_collection_title(wildcards) + " (inferred)",
364+
x_axis_label="inferred log2 titer",
365+
thresholds=[0.0, 2.0],
366+
filters=[
367+
"reference_strain",
368+
],
369+
include_columns=[
370+
"reference_strain",
371+
],
372+
shell:
373+
"""
374+
augur measurements export \
375+
--collection {input.distances} \
376+
--collection-config {input.configuration} \
377+
--include-columns {params.include_columns:q} \
378+
--strain-column {params.strain_column} \
379+
--value-column {params.value_column} \
380+
--key {wildcards.titer_collection}_inferred \
381+
--title {params.title:q} \
382+
--x-axis-label {params.x_axis_label:q} \
383+
--thresholds {params.thresholds} \
384+
--filters {params.filters} \
385+
--show-threshold \
386+
--hide-overall-mean \
387+
--minify-json \
388+
--output-json {output.measurements} 2>&1 | tee {log}
389+
"""
390+
235391
def get_titer_collections(wildcards):
236392
files = []
237393
for collection in config["builds"][wildcards.build_name]["titer_collections"]:
238394
files.append(f"builds/{wildcards.build_name}/{wildcards.segment}/measurements/{collection['name']}.json")
239395

396+
if collection.get("run_reference_models"):
397+
files.append(f"builds/{wildcards.build_name}/{wildcards.segment}/reference_model_measurements/{collection['name']}.json")
398+
240399
return files
241400

242401
rule concat_measurements:

0 commit comments

Comments
 (0)