Skip to content

Commit 9c4a37c

Browse files
avantikalal and lala8 authored
added robustness analysis (#17)
Co-authored-by: lala8 <[email protected]>
1 parent d164fd7 commit 9c4a37c

File tree

2 files changed

+99
-17
lines changed

2 files changed

+99
-17
lines changed

src/polygraph/models.py

Lines changed: 75 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -279,24 +279,92 @@ def ism_score(model, seqs, batch_size, device="cpu", task=None):
279279

280280
assert check_equal_lens(seqs)
281281

282+
# Predictions on original sequences
283+
preds = predict(seqs=seqs, model=model, batch_size=batch_size, device=device)
284+
assert preds.ndim < 3
285+
286+
# Select relevant task/cell type, or average predictions
287+
if task is None:
288+
if preds.ndim == 2:
289+
preds = preds.mean(1, keepdims=True)
290+
else:
291+
preds = preds[:, [task]]
292+
282293
# Mutate sequences
283294
ism = ISM(seqs) # N x L x 4
284295

285296
# Make predictions on mutated sequences
286-
preds = predict(seqs=ism, model=model, batch_size=batch_size, device=device)
287-
assert preds.ndim < 3
297+
ism_preds = predict(seqs=ism, model=model, batch_size=batch_size, device=device)
288298

289299
# Select relevant task/cell type, or average predictions
290300
if task is None:
291-
if preds.ndim == 2:
292-
preds = preds.mean(1)
301+
if ism_preds.ndim == 2:
302+
ism_preds = ism_preds.mean(1)
293303
else:
294-
preds = preds[:, task]
304+
ism_preds = ism_preds[:, task]
295305

296306
# Reshape predictions : N, L, 4
297-
preds = preds.reshape(len(seqs), len(ism) // (len(seqs) * 4), 4)
307+
ism_preds = ism_preds.reshape(len(seqs), len(ism) // (len(seqs) * 4), 4)
298308

299309
# Compute base-level importance score
300-
preds = np.log2(preds / preds.mean(-1, keepdims=True))
310+
preds = np.log2(ism_preds / preds)
301311
preds = np.abs(preds).max(-1)
302312
return preds
313+
314+
315+
def robustness(model, seqs, batch_size, device="cpu", task=None, aggfunc="mean"):
    """
    Get robustness scores for given sequence(s) using ISM

    Every single-base substitution of each sequence is scored by the relative
    change it causes in the model prediction,
    ``abs(pred(mutant) / pred(original) - 1)``, and these values are then
    aggregated over all mutants of the same sequence.

    Args:
        model (nn.Sequential): trained model
        seqs (list, pd.DataFrame): List of sequences or dataframe
            containing sequences in the column "Sequence". All sequences
            must have equal length.
        batch_size (int): Batch size for inference
        device (str, int): ID of GPU to perform inference.
        task (int, optional): Index of the model output (second axis of the
            predictions) to score. If None, 2-D predictions are averaged
            across that axis.
        aggfunc (str): Either 'mean' or 'max'. Determines how to aggregate the
            effect of all possible single-base mutations.

    Returns:
        (np.ndarray): One aggregated score per input sequence.

    Raises:
        ValueError: If aggfunc is neither 'mean' nor 'max'.
    """
    from polygraph.sequence import ISM
    from polygraph.utils import check_equal_lens

    # Validate the aggregation choice before running any (expensive) inference;
    # an unknown value previously fell through and returned None.
    aggregators = {"mean": np.mean, "max": np.max}
    if aggfunc not in aggregators:
        raise ValueError(f"aggfunc must be 'mean' or 'max', got {aggfunc!r}")

    assert check_equal_lens(seqs)

    # Predictions on original sequences
    preds = predict(seqs=seqs, model=model, batch_size=batch_size, device=device)
    assert preds.ndim < 3

    # Select relevant task/cell type, or average predictions.
    # Either way, end up with one column so it broadcasts over the mutants.
    if task is None:
        if preds.ndim == 2:
            preds = preds.mean(1, keepdims=True)
        else:
            # 1-D predictions: without the extra axis the division below
            # would align with the mutation axis instead of the sequence axis.
            preds = preds[:, None]
    else:
        preds = preds[:, [task]]

    # Mutate sequences, leaving out the reference base at each position
    ism = ISM(seqs, drop_ref=True)  # N x L x 3

    # Make predictions on mutated sequences
    ism_preds = predict(seqs=ism, model=model, batch_size=batch_size, device=device)

    # Select relevant task/cell type, or average predictions
    if task is None:
        if ism_preds.ndim == 2:
            ism_preds = ism_preds.mean(1)
    else:
        ism_preds = ism_preds[:, task]

    # Reshape predictions : N, Lx3 (one row of mutants per original sequence)
    n_seqs = len(seqs)
    ism_preds = ism_preds.reshape(n_seqs, len(ism) // n_seqs)

    # Compare mutated sequences to originals (relative change in prediction)
    deltas = np.abs((ism_preds / preds) - 1)

    # Aggregate over all possible mutations
    return aggregators[aggfunc](deltas, 1)

src/polygraph/sequence.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -261,36 +261,50 @@ def fastsk(seqs, k=5, m=2):
261261
return np.array(kernel.get_train_kernel())
262262

263263

264-
def ISM(seqs, drop_ref=False):
    """
    Perform in-silico mutagenesis on given DNA sequence(s)

    Args:
        seqs (str, list, pd.DataFrame): A DNA sequence, list of sequences
            or dataframe containing sequences in the column "Sequence".
        drop_ref (bool): If True, do not return the original sequence, i.e.
            at each position skip the base that is already present there.
            A position whose character is not in STANDARD_BASES still
            yields every standard base.

    Returns:
        (list): A list of all possible single-base mutated sequences
            derived from the original sequences. Mutants are ordered by
            input sequence, then by position, then by the order of
            STANDARD_BASES. Empty input gives an empty list.

    Raises:
        TypeError: If seqs is not a string, list or dataframe.
    """
    # ISM for a single sequence
    if isinstance(seqs, str):
        # Plain list building (no numpy round-trip) so that empty input
        # returns [] and the elements stay ordinary Python strings.
        return [
            seqs[:pos] + base + seqs[pos + 1 :]
            for pos, ref in enumerate(seqs)
            for base in STANDARD_BASES
            if not (drop_ref and base == ref)
        ]

    # Multiple sequences
    elif isinstance(seqs, list):
        return [
            mutant for seq in seqs for mutant in ISM(seq, drop_ref=drop_ref)
        ]

    # For a dataframe, mutate the contents of its "Sequence" column
    elif isinstance(seqs, pd.DataFrame):
        return ISM(seqs.Sequence.tolist(), drop_ref=drop_ref)

    else:
        raise TypeError("seqs must be a string, list or dataframe.")

0 commit comments

Comments
 (0)