@@ -279,24 +279,92 @@ def ism_score(model, seqs, batch_size, device="cpu", task=None):
279279
280280 assert check_equal_lens (seqs )
281281
282+ # Predictions on original sequences
283+ preds = predict (seqs = seqs , model = model , batch_size = batch_size , device = device )
284+ assert preds .ndim < 3
285+
286+ # Select relevant task/cell type, or average predictions
287+ if task is None :
288+ if preds .ndim == 2 :
289+ preds = preds .mean (1 , keepdims = True )
290+ else :
291+ preds = preds [:, [task ]]
292+
282293 # Mutate sequences
283294 ism = ISM (seqs ) # N x L x 4
284295
285296 # Make predictions on mutated sequences
286- preds = predict (seqs = ism , model = model , batch_size = batch_size , device = device )
287- assert preds .ndim < 3
297+ ism_preds = predict (seqs = ism , model = model , batch_size = batch_size , device = device )
288298
289299 # Select relevant task/cell type, or average predictions
290300 if task is None :
291- if preds .ndim == 2 :
292- preds = preds .mean (1 )
301+ if ism_preds .ndim == 2 :
302+ ism_preds = ism_preds .mean (1 )
293303 else :
294- preds = preds [:, task ]
304+ ism_preds = ism_preds [:, task ]
295305
296306 # Reshape predictions : N, L, 4
297- preds = preds .reshape (len (seqs ), len (ism ) // (len (seqs ) * 4 ), 4 )
307+ ism_preds = ism_preds .reshape (len (seqs ), len (ism ) // (len (seqs ) * 4 ), 4 )
298308
299309 # Compute base-level importance score
300- preds = np .log2 (preds / preds . mean ( - 1 , keepdims = True ) )
310+ preds = np .log2 (ism_preds / preds )
301311 preds = np .abs (preds ).max (- 1 )
302312 return preds
313+
314+
315+ def robustness (model , seqs , batch_size , device = "cpu" , task = None , aggfunc = "mean" ):
316+ """
317+ Get robustness scores for given sequence(s) using ISM
318+
319+ Args:
320+ seqs (list, pd.DataFrame): List of sequences or dataframe
321+ containing sequences in the column "Sequence".
322+ model (nn.Sequential): trained model
323+ batch_size (int): Batch size for inference
324+ device (str, int): ID of GPU to perform inference.
325+ aggfunc (str): Either 'mean' or 'max'. Determines how to aggregate the
326+ effect of all possible single-base mutations.
327+
328+ Returns:
329+ (pd.DataFrame): DataFrame of shape (n_seqs x n_outputs)
330+ """
331+ from polygraph .sequence import ISM
332+ from polygraph .utils import check_equal_lens
333+
334+ assert check_equal_lens (seqs )
335+
336+ # Predictions on original sequences
337+ preds = predict (seqs = seqs , model = model , batch_size = batch_size , device = device )
338+ assert preds .ndim < 3
339+
340+ # Select relevant task/cell type, or average predictions
341+ if task is None :
342+ if preds .ndim == 2 :
343+ preds = preds .mean (1 , keepdims = True )
344+ else :
345+ preds = preds [:, [task ]]
346+
347+ # Mutate sequences
348+ ism = ISM (seqs , drop_ref = True ) # N x L x 3
349+
350+ # Make predictions on mutated sequences
351+ ism_preds = predict (seqs = ism , model = model , batch_size = batch_size , device = device )
352+
353+ # Select relevant task/cell type, or average predictions
354+ if task is None :
355+ if ism_preds .ndim == 2 :
356+ ism_preds = ism_preds .mean (1 )
357+ else :
358+ ism_preds = ism_preds [:, task ]
359+
360+ # Reshape predictions : N, Lx3
361+ ism_preds = ism_preds .reshape (len (seqs ), len (ism ) // len (seqs ))
362+
363+ # Compare mutated sequences to originals
364+ deltas = np .abs ((ism_preds / preds ) - 1 )
365+
366+ # Aggregate over all possible mutations
367+ if aggfunc == "mean" :
368+ return np .mean (deltas , 1 )
369+ elif aggfunc == "max" :
370+ return np .max (deltas , 1 )
0 commit comments