Skip to content

Commit 9b2a262

Browse files
committed
Update calculation of allele frequency-weighted distances
Replace Python script with PathoGenOmics/afwdist, a faster, lightweight tool
1 parent fccdae8 commit 9b2a262

File tree

5 files changed

+167
-173
lines changed

5 files changed

+167
-173
lines changed

workflow/envs/afwdist.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# Conda environment for the afwdist tool (PathoGenOmics/afwdist), used by
# rule afwdist_weighted_distances to compute allele frequency-weighted
# pairwise distances. Version is pinned for reproducibility.
channels:
  - bioconda
  - conda-forge
dependencies:
  - afwdist==1.0.0

workflow/rules/distances.smk

Lines changed: 45 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,52 @@
1-
# Extract variants into the CSV layout expected by afwdist, appending the
# reconstructed ancestor as a pseudo-sample with fixed (frequency 1) alleles.
rule extract_afwdist_variants:
    conda: "../envs/biopython.yaml"
    params:
        # Column names in the pipeline's variants TSV that the script remaps
        # to afwdist's sample/position/sequence/frequency headers.
        sample_col = "SAMPLE",
        position_col = "POS",
        sequence_col = "ALT",
        frequency_col = "ALT_FREQ",
        # FILTER classes in the problematic-sites VCF whose positions are dropped.
        mask_class = ["mask"],
    input:
        variants = OUTDIR/f"{OUTPUT_NAME}.variants.tsv",
        mask_vcf = lambda wildcards: select_problematic_vcf(),
        ancestor = OUTDIR/f"{OUTPUT_NAME}.ancestor.fasta",
        reference = OUTDIR/"reference.fasta",
    output:
        variants = temp(OUTDIR/f"{OUTPUT_NAME}.variants.afwdist.csv"),
    log:
        LOGDIR/"extract_afwdist_variants"/"log.txt"
    script:
        "../scripts/extract_afwdist_variants.py"


# Compute allele frequency-weighted pairwise distances with the external
# afwdist tool (replaces the former weighted_distances.py script).
rule afwdist_weighted_distances:
    conda: "../envs/afwdist.yaml"
    params:
        extra_args = ""  # extra CLI flags for afwdist; none by default
    input:
        variants = OUTDIR/f"{OUTPUT_NAME}.variants.afwdist.csv",
        # NOTE(review): the "-r" reference handed to afwdist is the
        # reconstructed ancestor, not reference.fasta — confirm this is
        # intended (the variant table was built against the ancestor).
        reference = OUTDIR/f"{OUTPUT_NAME}.ancestor.fasta",
    output:
        distances = temp(REPORT_DIR_TABLES/"distances.raw.csv"),
    log:
        LOGDIR/"afwdist_weighted_distances"/"log.txt"
    shell:
        "afwdist "
        "-i {input.variants:q} "
        "-r {input.reference:q} "
        "-o {output.distances:q} "
        "{params.extra_args} >{log:q} 2>&1"


# Reshape afwdist's long-format pair list into a square sample-by-sample
# distance matrix for the report.
rule format_afwdist_results:
    conda: "../envs/biopython.yaml"
    params:
        samples = sorted(iter_samples()),
    input:
        distances = REPORT_DIR_TABLES/"distances.raw.csv",
    output:
        distances = REPORT_DIR_TABLES/"distances.csv",
    log:
        LOGDIR/"format_afwdist_results"/"log.txt"
    script:
        "../scripts/format_afwdist_results.py"
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
#!/usr/bin/env python3
2+
3+
import logging
4+
from typing import List
5+
6+
import pandas as pd
7+
from Bio import SeqIO
8+
from Bio.SeqRecord import SeqRecord
9+
from Bio.Seq import Seq
10+
11+
12+
def read_monofasta(path: str) -> SeqRecord:
    """Read a FASTA file expected to contain a single record.

    Returns the first record. Extra records are tolerated but trigger a
    warning, since downstream code only uses the first sequence.

    Raises:
        ValueError: if the file contains no FASTA records (instead of the
            bare StopIteration the previous `next(fasta)` call produced).
    """
    fasta = SeqIO.parse(path, "fasta")
    record = next(fasta, None)
    if record is None:
        raise ValueError(f"No FASTA records found in '{path}'")
    if next(fasta, None) is not None:
        logging.warning(f"There are unread records left in '{path}'")
    return record
18+
19+
20+
def read_masked_sites(vcf_path: str, mask_classes: List[str]) -> List[int]:
    """
    Parse a VCF containing positions for masking and return the positions
    whose FILTER value is one of `mask_classes`.

    Assumes the VCF file is formatted as in:
    github.com/W-L/ProblematicSites_SARS-CoV2/blob/master/problematic_sites_sarsCov2.vcf
    with a "mask" or "caution" recommendation in column 7.
    Masked sites are specified with params.
    """
    vcf = pd.read_csv(
        vcf_path,
        # Raw string: "\s" is an invalid escape in a plain literal and
        # raises a SyntaxWarning on modern Python.
        sep=r"\s+",
        comment="#",  # skips the ## metadata and #CHROM header lines
        names=("CHROM", "POS", "ID", "REF", "ALT", "QUAL", "FILTER", "INFO")
    )
    return vcf.loc[vcf.FILTER.isin(mask_classes), "POS"].tolist()
35+
36+
37+
def build_ancestor_variant_table(ancestor: Seq, reference: Seq, reference_name: str, masked_positions: List[int]) -> pd.DataFrame:
    """Encode ancestor/reference mismatches as variants of a pseudo-sample.

    Walks both sequences position by position (1-based, VCF convention) and
    records the reference allele wherever the ancestor differs and the site
    is not masked. Column names come from snakemake.params so they match the
    sample variant table. Assumes ancestor and reference are equal-length
    aligned sequences — TODO confirm upstream guarantees this.
    """
    # Set for O(1) membership tests; the previous list lookup made the scan
    # O(genome length x masked sites).
    masked = set(masked_positions)
    pos: List[int] = []
    alt: List[str] = []
    for i in range(1, len(ancestor) + 1):
        if i not in masked and ancestor[i-1] != reference[i-1]:
            pos.append(i)
            alt.append(reference[i-1])
    df = pd.DataFrame({snakemake.params.position_col: pos, snakemake.params.sequence_col: alt})
    df[snakemake.params.frequency_col] = 1  # As a reference genome, we assume all positions have fixed alleles
    df[snakemake.params.sample_col] = reference_name
    return df
48+
49+
50+
if __name__ == "__main__":

    # Snakemake injects `snakemake` with log/config/params/input/output.
    logging.basicConfig(filename=snakemake.log[0], format=snakemake.config["LOG_PY_FMT"], level=logging.INFO)

    # Map the pipeline's variant-table column names to the headers the
    # afwdist input CSV expects.
    colnames = {
        snakemake.params.sample_col: "sample",
        snakemake.params.position_col: "position",
        snakemake.params.sequence_col: "sequence",
        snakemake.params.frequency_col: "frequency"
    }

    logging.info("Reading input tables")
    # Variants
    variants = pd.read_table(snakemake.input.variants, sep="\t")
    logging.info(f"Read {len(variants)} variant records")
    # VCF with sites to mask
    masked_sites = read_masked_sites(snakemake.input.mask_vcf, snakemake.params.mask_class)
    logging.info(f"Read {len(masked_sites)} masked positions")

    logging.info("Reading input FASTA files")
    # Case ancestor
    ancestor = read_monofasta(snakemake.input.ancestor)
    logging.info(f"Ancestor: '{ancestor.description}', length={len(ancestor.seq)}")
    # Alignment reference
    reference = read_monofasta(snakemake.input.reference)
    logging.info(f"Reference: '{reference.description}', length={len(reference.seq)}")

    logging.info("Processing ancestor variants")
    # The ancestor is appended as a pseudo-sample (named after the reference
    # record id) so afwdist also computes sample-to-ancestor distances.
    ancestor_table = build_ancestor_variant_table(ancestor.seq, reference.seq, reference.id, masked_sites)
    logging.info(f"Ancestor has {len(ancestor_table)} variants")
    all_variants = pd.concat([variants, ancestor_table], ignore_index=True)
    logging.info(f"Combined table has {len(all_variants)} variants")

    logging.info("Renaming and selecting columns")
    output = all_variants.rename(columns=colnames)[list(colnames.values())]
    logging.info("Filtering sites")
    # Drops masked positions from the sample variants; the ancestor table
    # was already built without them, so this is a no-op for those rows.
    output = output[~output.position.isin(masked_sites)]
    logging.info(f"There are {len(output)} rows left")

    logging.info("Writing results")
    output.to_csv(snakemake.output.variants, index=False)
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
#!/usr/bin/env python3
2+
3+
import logging
4+
import pandas as pd
5+
6+
7+
if __name__ == "__main__":

    # Snakemake injects `snakemake` with log/config/params/input/output.
    logging.basicConfig(filename=snakemake.log[0], format=snakemake.config["LOG_PY_FMT"], level=logging.INFO)

    logging.info("Read pairwise distances")
    # afwdist long-format output: one row per sample pair with columns
    # sample_m, sample_n and distance (assumed schema — TODO confirm
    # against afwdist's documentation).
    df = pd.read_csv(snakemake.input.distances)

    logging.info("Initializing formatted output")
    # Square sample-by-sample matrix. Cells not present in the input stay
    # NaN — including the diagonal if afwdist omits self-pairs; verify
    # whether downstream consumers expect 0.0 there.
    output = pd.DataFrame(
        columns=snakemake.params.samples,
        index=snakemake.params.samples,
        dtype="float64"
    )

    logging.info("Filling table")
    # Mirror each pair into both triangles so the matrix is symmetric.
    for i, row in df.iterrows():
        output.loc[row.sample_m, row.sample_n] = row.distance
        output.loc[row.sample_n, row.sample_m] = row.distance

    logging.info("Writing formatted results")
    output.to_csv(snakemake.output.distances)

workflow/scripts/weighted_distances.py

Lines changed: 0 additions & 163 deletions
This file was deleted.

0 commit comments

Comments
 (0)