Skip to content

Commit 3497f5e

Browse files
Merge pull request #572 from SamStudio8/cw-6236
Reduce memory usage of SV postprocessing
2 parents 6c51a53 + 8b50060 commit 3497f5e

File tree

2 files changed

+156
-70
lines changed

2 files changed

+156
-70
lines changed

src/sniffles/postprocessing.py

Lines changed: 116 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,34 @@
88
# Maintainer: Hermann Romanek
99
# Contact: sniffles@romanek.at
1010
#
11-
from functools import partial
12-
from typing import Callable
11+
from enum import IntEnum
12+
import math
13+
import logging
1314

1415
from sniffles import util
1516
from sniffles import consensus
1617
from sniffles.config import SnifflesConfig
1718
from sniffles.sv import SVCall
18-
import math
19+
20+
log = logging.getLogger('sniffles.postprocessing')
21+
22+
class CoverageMode(IntEnum):
    """
    A cheap(er) class to represent the coverage mode used when setting SVCall attributes.

    Each member identifies one of the five coverage fields; `to_attr` gives the
    name of the corresponding attribute.
    """
    START = 1
    CENTER = 2
    END = 3
    UPSTREAM = 4
    DOWNSTREAM = 5

    @property
    def to_attr(self):
        """Return the name of the coverage attribute this mode refers to."""
        return _COVERAGE_MODE_ATTRS[self]


# Built once at import time rather than on every `to_attr` access, since the
# property is read for every coverage request that gets fulfilled.
_COVERAGE_MODE_ATTRS = {
    CoverageMode.START: "coverage_start",
    CoverageMode.CENTER: "coverage_center",
    CoverageMode.END: "coverage_end",
    CoverageMode.UPSTREAM: "coverage_upstream",
    CoverageMode.DOWNSTREAM: "coverage_downstream",
}
1939

2040

2141
def annotate_sv(svcall: SVCall, config):
@@ -62,29 +82,19 @@ def annotate_sv(svcall: SVCall, config):
6282
svcall.alt = best_lead.seq
6383

6484

65-
def add_request(svcall_i: int, field: "CoverageMode", pos, requests_for_coverage: dict[int, set[tuple[int, int]]], config):
    """
    Register a request for one of the five coverage fields of an SVCall.

    The call is identified by its index (`svcall_i`) rather than by the object
    itself; the request is stored as the pair `(svcall_i, field)` in the set kept
    for the coverage bin that contains `pos`.

    :param svcall_i: index of the SVCall the request belongs to
    :param field: which coverage field is being requested
    :param pos: position at which the coverage should be read
    :param requests_for_coverage: mapping of bin start -> set of (index, field)
        pairs; updated in place
    :param config: provides `coverage_binsize`
    """
    # int() truncates toward zero, so a position in (-binsize, 0) lands in bin 0
    # rather than in a negative bin.
    bin_index = int(pos / config.coverage_binsize) * config.coverage_binsize
    if bin_index not in requests_for_coverage:
        requests_for_coverage[bin_index] = set()
    requests_for_coverage[bin_index].add((svcall_i, field))
8393

8494

8595
def coverage(calls, lead_provider, config):
    """
    Collect the coverage requests for all given calls and fulfill them.

    `coverage_build_requests` produces three request tables: per-bin attribute
    requests, and the bins at which large-SV coverage sampling starts and ends.
    These are handed to `coverage_fulfill` together with the calls, so that
    requests can refer to calls by index.

    :return: whatever `coverage_fulfill` returns
    """
    attr_requests, sample_starts, sample_ends = coverage_build_requests(calls, config)
    return coverage_fulfill(calls, attr_requests, sample_starts, sample_ends, lead_provider, config)
8898

8999

90100
def coverage_build_requests(calls, config: SnifflesConfig):
@@ -98,8 +108,11 @@ def coverage_build_requests(calls, config: SnifflesConfig):
98108
INV U--- S--- E--- D---
99109
C---
100110
"""
101-
requests_for_coverage = {}
102-
for svcall in calls:
111+
requests_for_coverage_attrs = {}
112+
requests_for_coverage_sample_starts = {}
113+
requests_for_coverage_sample_ends = {}
114+
115+
for sv_i, svcall in enumerate(calls):
103116
start = svcall.pos
104117
if svcall.svtype == "INS":
105118
end = start + 1
@@ -108,33 +121,46 @@ def coverage_build_requests(calls, config: SnifflesConfig):
108121

109122
if svcall.svtype in ("DEL", 'INV', 'DUP') and abs(svcall.svlen) >= config.long_del_length:
110123
# Sampling more intervals for large deletions
111-
for loc in range(start, end, config.large_coverage_sample_interval):
112-
add_sampling_point_request(svcall, loc, requests_for_coverage, config)
124+
clean_start = start
125+
if start < 0:
126+
# Excusing a negative start, we'll find the first positive position we'd encounter striding our large coverage interval and start there instead
127+
clean_start = start + ((-start // config.large_coverage_sample_interval) + 1) * config.large_coverage_sample_interval
128+
log.warning(f"[sv] Encountered SV={svcall.id} with negative pos={start}. Setting to first viable positive pos={clean_start} for coverage postprocessing.")
129+
first_bin = (clean_start // config.coverage_binsize) * config.coverage_binsize
130+
last_bin_pos = clean_start + ((end - clean_start - 1) // config.large_coverage_sample_interval) * config.large_coverage_sample_interval
131+
last_bin = (last_bin_pos // config.coverage_binsize) * config.coverage_binsize
132+
133+
if first_bin not in requests_for_coverage_sample_starts:
134+
requests_for_coverage_sample_starts[first_bin] = set()
135+
if last_bin not in requests_for_coverage_sample_ends:
136+
requests_for_coverage_sample_ends[last_bin] = set()
137+
requests_for_coverage_sample_starts[first_bin].add(sv_i)
138+
requests_for_coverage_sample_ends[last_bin].add(sv_i)
113139

114140
if svcall.svtype in ["INS", "BND"]:
115-
add_request(svcall, "coverage_start", start - config.coverage_binsize, requests_for_coverage, config)
116-
add_request(svcall, "coverage_center", int((start + end - config.coverage_binsize) / 2),
117-
requests_for_coverage, config)
118-
add_request(svcall, "coverage_end", end + config.coverage_binsize, requests_for_coverage, config)
119-
add_request(svcall, "coverage_upstream", start - config.coverage_binsize * config.coverage_updown_bins,
120-
requests_for_coverage, config)
121-
add_request(svcall, "coverage_downstream", end + config.coverage_binsize * config.coverage_updown_bins,
122-
requests_for_coverage, config)
141+
add_request(sv_i, CoverageMode.START, start - config.coverage_binsize, requests_for_coverage_attrs, config)
142+
add_request(sv_i, CoverageMode.CENTER, int((start + end - config.coverage_binsize) / 2),
143+
requests_for_coverage_attrs, config)
144+
add_request(sv_i, CoverageMode.END, end + config.coverage_binsize, requests_for_coverage_attrs, config)
145+
add_request(sv_i, CoverageMode.UPSTREAM, start - config.coverage_binsize * config.coverage_updown_bins,
146+
requests_for_coverage_attrs, config)
147+
add_request(sv_i, CoverageMode.DOWNSTREAM, end + config.coverage_binsize * config.coverage_updown_bins,
148+
requests_for_coverage_attrs, config)
123149
else:
124-
add_request(svcall, "coverage_start", start, requests_for_coverage, config)
125-
add_request(svcall, "coverage_center", int((start + end - config.coverage_binsize) / 2),
126-
requests_for_coverage, config)
127-
add_request(svcall, "coverage_end", end - config.coverage_binsize, requests_for_coverage, config)
128-
add_request(svcall, "coverage_upstream", start - config.coverage_binsize * config.coverage_updown_bins,
129-
requests_for_coverage, config)
130-
add_request(svcall, "coverage_downstream", end + config.coverage_binsize * config.coverage_updown_bins,
131-
requests_for_coverage, config)
150+
add_request(sv_i, CoverageMode.START, start, requests_for_coverage_attrs, config)
151+
add_request(sv_i, CoverageMode.CENTER, int((start + end - config.coverage_binsize) / 2),
152+
requests_for_coverage_attrs, config)
153+
add_request(sv_i, CoverageMode.END, end - config.coverage_binsize, requests_for_coverage_attrs, config)
154+
add_request(sv_i, CoverageMode.UPSTREAM, start - config.coverage_binsize * config.coverage_updown_bins,
155+
requests_for_coverage_attrs, config)
156+
add_request(sv_i, CoverageMode.DOWNSTREAM, end + config.coverage_binsize * config.coverage_updown_bins,
157+
requests_for_coverage_attrs, config)
132158

133-
return requests_for_coverage
159+
return requests_for_coverage_attrs, requests_for_coverage_sample_starts, requests_for_coverage_sample_ends
134160

135161

136-
def coverage_fulfill(requests_for_coverage, lead_provider, config: SnifflesConfig):
137-
if len(requests_for_coverage) == 0:
162+
def coverage_fulfill(calls, requests_for_coverage_attrs, requests_for_coverage_sample_starts, requests_for_coverage_sample_ends, lead_provider, config: SnifflesConfig):
163+
if len(requests_for_coverage_attrs) == 0 and len(requests_for_coverage_sample_starts) == 0:
138164
return -1, -1
139165

140166
start_bin = lead_provider.covrtab_min_bin
@@ -145,19 +171,57 @@ def coverage_fulfill(requests_for_coverage, lead_provider, config: SnifflesConfi
145171
coverage_rev_total = 0
146172
n = 0
147173

148-
for bin_index in range(start_bin, end_bin + config.coverage_binsize, config.coverage_binsize):
174+
# We can efficiently store and query for large coverage sampling intervals (LCSI) by
175+
# maintaining a set for each possible bin modulus which stores the indices of SVCalls
176+
# that are open to be sampled at all bins with the same modulus.
177+
# The diagram below illustrates a series of bins and how this approach would work.
178+
# Consider an LCSI every 4 bins (actual positions ignored for simplicity).
179+
# Here SV_i starts at a position between bin 1 and 2 and ends at a position
180+
# between bin 8 and 9. As the coverage postprocessing forces sampling to start and
181+
# end on a coverage bin, sampling of SV_i will start at bin 1 and end at bin 8.
182+
# Bin 1 has a modulus of 1 with respect to the LCSI of 4, as does 5, 9 and so on.
183+
# For each possible modulus (in this case 0, 1, 2 and 3) we can maintain a set of
184+
# SVCall indices that are open to be sampled at the next bin with the same modulus.
185+
#
186+
# id 0 1 2 3 4 5 6 7 8 9 10 11 12 ..
187+
# id%4 0 1 2 3 0 1 2 3 0 1 2 3 0 ...
188+
# Bins o----o----o----o----o----o----o----o----o----o----o----o----o ...
189+
# LCSI%0 X-------------------X-------------------X-------------------X ...
190+
# SV_i ===S-------------------------------X==E
191+
# ^ First Bin has modulus 1 ^ Last Bin
192+
# we'll sample all modulus 1 bins before last bin
193+
# LCSI%1 X-------------------X-------------------X-------------->...
194+
# SV_i LCSI X X
195+
bin_moduli = {}
196+
197+
for bin_pos in range(start_bin, end_bin + config.coverage_binsize, config.coverage_binsize):
198+
bin_modulus = bin_pos % config.large_coverage_sample_interval
149199
n += 1
150200

151-
if bin_index in lead_provider.covrtab_fwd:
152-
coverage_fwd += lead_provider.covrtab_fwd[bin_index]
153-
154-
if bin_index in lead_provider.covrtab_rev:
155-
coverage_rev += lead_provider.covrtab_rev[bin_index]
156-
157-
if bin_index in requests_for_coverage:
158-
coverage_total_curr = coverage_fwd + coverage_rev
159-
for set_coverage_fn in requests_for_coverage[bin_index]:
160-
set_coverage_fn(coverage_total_curr)
201+
coverage_fwd += lead_provider.covrtab_fwd.get(bin_pos, 0) # TODO -- should this be reset
202+
coverage_rev += lead_provider.covrtab_rev.get(bin_pos, 0) # TODO -- should this be reset
203+
coverage_total_curr = coverage_fwd + coverage_rev
204+
205+
# handle attr style coverage requests for this bin position
206+
for sv_i, mode in requests_for_coverage_attrs.get(bin_pos, []):
207+
setattr(calls[sv_i], mode.to_attr, coverage_total_curr)
208+
209+
# open
210+
# if this bin position opens SVs, add each SVCall index to the set of SVs to be
211+
# updated for any bin with the same modulo
212+
if bin_pos in requests_for_coverage_sample_starts:
213+
if bin_modulus not in bin_moduli:
214+
bin_moduli[bin_modulus] = set()
215+
bin_moduli[bin_modulus].update(requests_for_coverage_sample_starts[bin_pos])
216+
217+
# sample welfordly
218+
for sv_i in bin_moduli.get(bin_modulus, []):
219+
calls[sv_i].forward_difference_sampler.push(coverage_total_curr)
220+
221+
# close
222+
# this is the last bin, do not process any more overlaps on the next iteration
223+
if bin_pos in requests_for_coverage_sample_ends:
224+
bin_moduli[bin_modulus].difference_update(requests_for_coverage_sample_ends[bin_pos])
161225

162226
coverage_fwd_total += coverage_fwd
163227
coverage_rev_total += coverage_rev

src/sniffles/sv.py

Lines changed: 40 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,9 @@
99
# Contact: sniffles@romanek.at
1010
#
1111
import logging
12-
from dataclasses import dataclass
12+
from dataclasses import dataclass, field
1313
from typing import Optional, Callable
1414

15-
import numpy as np
16-
1715
try:
1816
from edlib import align
1917
except ImportError:
@@ -38,6 +36,42 @@ class SVCallPostprocessingInfo:
3836
cluster: list
3937

4038

39+
class ForwardDifferenceWelford:
    """
    Streaming mean and variance of the relative forward differences of a series.

    The first value pushed only seeds the sampler. Each later value contributes
    one observation, `(value - previous) / (previous + 1e-10)`, which is folded
    into a running mean and a running sum of squared deviations (Welford's online
    update), so the individual samples never need to be stored.
    """

    def __init__(self):
        self.n = 0        # number of forward differences folded in so far
        self.m1 = 0       # running mean of the forward differences
        self.m2 = 0       # running sum of squared deviations from the mean
        self.last = None  # most recently pushed value, None until the first push

    def push(self, value):
        """Feed the next value of the series into the sampler."""
        previous = self.last
        self.last = value
        if previous is None:
            # first observation just sets the last seen observation,
            # forward differences are calculated from here
            return

        # TODO epsilon to avoid division by zero as before - is this appropriate?
        diff = (value - previous) / (previous + 1e-10)

        seen = self.n
        count = seen + 1
        delta = diff - self.m1
        delta_n = delta / count
        self.m1 += delta_n
        self.m2 += delta * delta_n * seen
        self.n = count

    @property
    def mean(self):
        """Mean of the forward differences, or None if there are none yet."""
        if self.n == 0:
            return None
        return self.m1

    @property
    def variance(self):
        """Population variance (ddof=0) of the forward differences, or None with fewer than two."""
        if self.n < 2:
            return None
        return self.m2 / self.n  # ddof=0
74+
4175
@dataclass
4276
class SVCall:
4377
contig: str
@@ -67,12 +101,12 @@ class SVCall:
67101
fwd: int = None
68102
rev: int = None
69103

104+
forward_difference_sampler: ForwardDifferenceWelford = field(default_factory=ForwardDifferenceWelford)
70105
coverage_upstream: int = 0
71106
coverage_downstream: int = 0
72107
coverage_start: int = 0
73108
coverage_center: int = 0
74109
coverage_end: int = 0
75-
coverage_samples: dict = None
76110

77111
sample_internal_id: int = None
78112
bnd_info: SVCallBNDInfo = None
@@ -94,24 +128,12 @@ def has_info(self, k):
94128
def finalize(self):
95129
self.postprocess = None
96130

97-
def add_coverage_sample(self, pos: int, coverage: int):
98-
if self.coverage_samples is None:
99-
self.coverage_samples = {}
100-
self.coverage_samples[pos] = coverage
101-
102131
def qc_coverage_samples(self) -> tuple[bool, float | None]:
103-
"""
104-
Check if coverage sampling indicates a QC pass. Returns True if this Call passes QC, False otherwise.
105-
"""
106-
if not self.coverage_samples:
132+
var = self.forward_difference_sampler.variance
133+
if var is None:
107134
return True, None
108-
109-
samples = np.fromiter(self.coverage_samples.values(), int)
110-
diffs = np.diff(samples) / (samples[:-1] + 1e-10) # epsilon to avoid division by zero
111-
var = np.var(diffs)
112135
return var < 0.3, float(var)
113136

114-
115137
@dataclass
116138
class SVGroup:
117139
"""

0 commit comments

Comments
 (0)