Skip to content

Commit 08b60e3

Browse files
author
lala8
committed
fixed ism function
1 parent b0a0f48 commit 08b60e3

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

src/polygraph/models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -305,10 +305,10 @@ def ism_score(model, seqs, batch_size, device="cpu", task=None):
305305

306306
# Reshape predictions : N, L, 4
307307
ism_preds = ism_preds.reshape(len(seqs), len(ism) // (len(seqs) * 4), 4)
308+
ism_preds = ism_preds.max(-1)
308309

309310
# Compute base-level importance score
310-
preds = np.log2(ism_preds / preds)
311-
preds = np.abs(preds).max(-1)
311+
preds = np.abs(ism_preds - preds)
312312
return preds
313313

314314

0 commit comments

Comments
 (0)