Skip to content

Commit 613c96c

Browse files
authored
Merge pull request #162 from Genentech/compare-motifs
Compare motifs
2 parents a30910e + 70f7fd8 commit 613c96c

File tree

4 files changed

+351
-211
lines changed

4 files changed

+351
-211
lines changed

docs/tutorials/5_variant.ipynb

Lines changed: 295 additions & 191 deletions
Large diffs are not rendered by default.

src/grelu/interpret/motifs.py

Lines changed: 32 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ def scan_sequences(
155155

156156
# Format motifs
157157
if isinstance(motifs, str):
158-
motifs = read_meme_file(motifs)
158+
motifs = read_meme_file(motifs, names=names)
159159

160160
import tempfile
161161

@@ -319,8 +319,6 @@ def compare_motifs(
319319
motifs: A dictionary whose values are Position Probability Matrices
320320
(PPMs) of shape (4, L), or the path to a MEME file.
321321
alt_seq: The alternate sequence as a string
322-
ref_allele: The alternate allele as a string. Only used if
323-
alt_seq is not supplied.
324322
alt_allele: The alternate allele as a string. Only needed if
325323
alt_seq is not supplied.
326324
pos: The position at which to substitute the alternate allele.
@@ -345,24 +343,39 @@ def compare_motifs(
345343
names=names,
346344
seq_ids=["ref", "alt"],
347345
pthresh=pthresh,
348-
rc=True, # Scan both strands
346+
rc=rc, # Scan both strands only if rc is True
349347
)
350-
351-
# Compare the results for alt and ref sequences
352-
scan = (
353-
scan.pivot_table(
354-
index=["motif", "start", "end", "strand"],
355-
columns=["sequence"],
356-
values="score",
348+
if len(scan) > 0:
349+
350+
# Compare the results for alt and ref sequences
351+
scan = (
352+
scan.pivot_table(
353+
index=["motif", "start", "end", "strand"],
354+
columns=["sequence"],
355+
values=["score", "p-value"],
356+
)
357+
.reset_index()
357358
)
358-
.fillna(0)
359-
.reset_index()
360-
)
361-
362-
# Compute fold change
363-
scan["foldChange"] = scan.alt / scan.ref
364-
scan = scan.sort_values("foldChange").reset_index(drop=True)
365-
return scan
359+
scan.columns = [col[0] if col[1] == '' else '_'.join(col) for col in scan.columns]
360+
for col in ["p-value_alt", "p-value_ref", "score_alt", "score_ref"]:
361+
if col not in scan.columns:
362+
scan[col] = np.nan
363+
364+
# Fill in missing scores for sites that passed the threshold in only one of the two sequences
365+
for row in scan[scan.score_alt.isna()].itertuples():
366+
sc = scan_sequences(seqs=alt_seq[row.start:row.end+1], motifs=motifs, names=[row.motif], pthresh=1, rc=row.strand=='-').iloc[0]
367+
scan.loc[row.Index, 'score_alt'] = sc.score
368+
scan.loc[row.Index, 'p-value_alt'] = sc['p-value']
369+
370+
for row in scan[scan.score_ref.isna()].itertuples():
371+
sc = scan_sequences(seqs=ref_seq[row.start:row.end+1], motifs=motifs, names=[row.motif], pthresh=1, rc=row.strand=='-').iloc[0]
372+
scan.loc[row.Index, 'score_ref'] = sc.score
373+
scan.loc[row.Index, 'p-value_ref'] = sc['p-value']
374+
375+
# Compute the score difference (alt - ref)
376+
scan["score_diff"] = scan.score_alt - scan.score_ref
377+
scan = scan.sort_values("score_diff").reset_index(drop=True)
378+
return scan
366379

367380

368381
def run_tomtom(motifs: Dict[str, np.ndarray], meme_file: str) -> pd.DataFrame:

tests/test_interpret.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
run_tomtom,
1010
scan_sequences,
1111
trim_pwm,
12+
compare_motifs
1213
)
1314
from grelu.interpret.score import ISM_predict, get_attention_scores, get_attributions
1415
from grelu.interpret.simulate import (
@@ -268,6 +269,28 @@ def test_scan_sequences():
268269
assert out.equals(expected)
269270

270271

272+
def test_compare_motifs():
273+
out = compare_motifs(
274+
ref_seq="CACGTGACGCATGA",
275+
motifs=meme_file,
276+
alt_seq="TAAGTGACGCGTGA",
277+
pthresh = 5e-4,
278+
rc=False
279+
)
280+
expected = pd.DataFrame({
281+
'motif': ['MA0004.1 Arnt', 'MA0006.1 Ahr::Arnt'],
282+
'start': [0, 7],
283+
'end': [6, 13],
284+
'strand': ['+', '+'],
285+
'p-value_alt': [0.015625, 0.000244140625],
286+
'p-value_ref': [0.000244140625, 0.010009765624999995],
287+
'score_alt': [-14.648840188980103, 10.232005834579468],
288+
'score_ref': [11.60498046875, -2.9944558143615723],
289+
'score_diff': [-26.253820657730103, 13.22646164894104]
290+
})
291+
assert out.equals(expected)
292+
293+
271294
def test_run_tomtom():
272295

273296
motifs = {

tests/test_preprocess.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -290,5 +290,5 @@ def test_get_gc_matched_intervals():
290290
)
291291

292292
res = get_gc_matched_intervals(
293-
intervals=intervals, genome='hg38', chroms=['chr21'])
293+
intervals=intervals, genome='hg38', chroms=['chr21'], seed=0)
294294
assert len(res) == 1

0 commit comments

Comments
 (0)