Skip to content

Commit bb64d4e

Browse files
authored
Merge pull request #138 from Genentech/improve-test
updated test scan sequences
2 parents b15f244 + 3daeb25 commit bb64d4e

File tree

1 file changed

+41
-22
lines changed

1 file changed

+41
-22
lines changed

tests/test_interpret.py

Lines changed: 41 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -215,38 +215,57 @@ def test_get_attention_scores():
215215

216216

217217
def test_scan_sequences():
218-
seqs = ["TCACGTGAA", "CCTGCGTGA", "CACGCAGGA"]
218+
seqs = ["TCACGTGAA", "CACGCAGGA", "CCTGCGTGA"]
219219

220220
# No reverse complement
221221
out = scan_sequences(seqs, motifs=meme_file, rc=False, pthresh=1e-3)
222-
assert out.motif.tolist() == ["MA0004.1 Arnt", "MA0006.1 Ahr::Arnt"]
223-
assert out.sequence.tolist() == ["0", "1"]
224-
assert out.start.tolist() == [1, 2]
225-
assert out.end.tolist() == [7, 8]
226-
assert out.strand.tolist() == ["+", "+"]
227-
assert out.matched_seq.tolist() == ["CACGTG", "TGCGTG"]
222+
expected = pd.DataFrame({
223+
'motif': ['MA0004.1 Arnt', 'MA0006.1 Ahr::Arnt'],
224+
'sequence': ['0', '2'],
225+
'seq_idx': [0, 2],
226+
'start': [1, 2],
227+
'end': [7, 8],
228+
'strand': ['+', '+'],
229+
'score': [11.60498046875, 10.691319823265076],
230+
'p-value': [0.000244140625, 0.000244140625],
231+
'matched_seq': ['CACGTG', 'TGCGTG']
232+
})
233+
assert out.equals(expected)
228234

229235
# Allow reverse complement
230236
out = scan_sequences(seqs, motifs=meme_file, rc=True, pthresh=1e-3)
231-
assert out.motif.tolist() == [
232-
"MA0004.1 Arnt",
233-
"MA0004.1 Arnt",
234-
"MA0006.1 Ahr::Arnt",
235-
"MA0006.1 Ahr::Arnt",
236-
]
237-
assert out.sequence.tolist() == ["0", "0", "1", "2"]
238-
assert out.start.tolist() == [1, 1, 2, 0]
239-
assert out.end.tolist() == [7, 7, 8, 6]
240-
assert out.strand.tolist() == ["+", "-", "+", "-"]
241-
assert out.matched_seq.tolist() == ["CACGTG", "CACGTG", "TGCGTG", "CACGCA"]
237+
238+
expected = pd.DataFrame({
239+
'motif': ['MA0004.1 Arnt', 'MA0004.1 Arnt','MA0006.1 Ahr::Arnt', 'MA0006.1 Ahr::Arnt'],
240+
'sequence': ['0', '0', '1', '2'],
241+
'seq_idx': [0, 0, 1, 2],
242+
'start': [1, 1, 0, 2],
243+
'end': [7, 7, 6, 8],
244+
'strand': ['+', '-', '-', '+'],
245+
'score': [11.60498046875, 11.60498046875, 10.691319823265076, 10.691319823265076],
246+
'p-value': [0.000244140625, 0.000244140625, 0.000244140625, 0.000244140625],
247+
'matched_seq': ['CACGTG', 'CACGTG', 'CACGCA', 'TGCGTG']
248+
})
249+
250+
assert out.equals(expected)
242251

243252
# Reverse complement with attributions
244253
attrs = get_attributions(model, seqs, method="inputxgradient")
245254
out = scan_sequences(seqs, motifs=meme_file, rc=True, pthresh=1e-3, attrs=attrs)
246-
assert np.allclose(out.site_attr_score, [0.0, 0.0, -0.009259, 0.009259], rtol=0.001)
247-
assert np.allclose(
248-
out.motif_attr_score, [0.003704, 0.0, -0.035494, 0.0], rtol=0.001
249-
)
255+
expected = pd.DataFrame({
256+
'motif': ['MA0004.1 Arnt', 'MA0004.1 Arnt', 'MA0006.1 Ahr::Arnt', 'MA0006.1 Ahr::Arnt'],
257+
'sequence': ['0', '0', '1', '2'],
258+
'seq_idx': [0, 0, 1, 2],
259+
'start': [1, 1, 0, 2],
260+
'end': [7, 7, 6, 8],
261+
'strand': ['+', '-', '-', '+'],
262+
'score': [11.60498046875, 11.60498046875, 10.691319823265076, 10.691319823265076],
263+
'p-value': [0.000244140625, 0.000244140625, 0.000244140625, 0.000244140625],
264+
'matched_seq': ['CACGTG', 'CACGTG', 'CACGCA', 'TGCGTG'],
265+
'site_attr_score': np.float32([0.0, 0.0, 0.009259258396923542, -0.009259259328246117]),
266+
'motif_attr_score': [0.003703703731298441, 0.0, 0.0, -0.03549381507926434]
267+
})
268+
assert out.equals(expected)
250269

251270

252271
def test_run_tomtom():

0 commit comments

Comments (0)