Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/grelu/interpret/modisco.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def run_modisco(
n_shuffles: int = 10,
seed=None,
method: str = "deepshap",
correct_grad: bool = False,
**kwargs,
) -> None:
"""
Expand All @@ -139,6 +140,8 @@ def run_modisco(
n_shuffles: Number of times to shuffle the background sequences for deepshap.
seed: Random seed
method: Either "deepshap", "saliency" or "ism".
correct_grad: If True, gradients will be corrected using the method of Majdandzic et al.
(PMID: 37161475). Only used with method='saliency'.
**kwargs: Additional arguments to pass to TF-Modisco.

Raises:
Expand Down Expand Up @@ -180,6 +183,7 @@ def run_modisco(
hypothetical=True,
genome=genome,
seed=seed,
correct_grad=correct_grad,
)
attrs = attrs[:, :, start:end]

Expand Down
17 changes: 15 additions & 2 deletions src/grelu/interpret/score.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def get_attributions(
prediction_transform: Optional[Callable] = None,
device: Union[str, int] = "cpu",
method: str = "deepshap",
correct_grad: bool = False,
hypothetical: bool = False,
n_shuffles: int = 20,
seed=None,
Expand All @@ -144,8 +145,10 @@ def get_attributions(
prediction_transform: A module to transform the model output
devices: Indices of the devices to use for inference
method: One of "deepshap", "saliency", "inputxgradient" or "integratedgradients"
hypothetical: whether to calculate hypothetical importance scores.
Set this to True to obtain input for tf-modisco, False otherwise
correct_grad: If True, gradients will be corrected using the method of Majdandzic et al.
(PMID: 37161475). Only used with method='saliency'.
hypothetical: Only used with method = "deepshap". If true, the function will return
hypothetical importance scores which can be used as input for tf-modisco.
n_shuffles: Number of times to dinucleotide shuffle sequence
seed: Random seed
**kwargs: Additional arguments to pass to tangermeme.deep_lift_shap.deep_lift_shap
Expand Down Expand Up @@ -210,6 +213,16 @@ def get_attributions(

# Remove transform
model.reset_transform()

# Correct gradients
if correct_grad:
if method != "saliency":
warnings.warn(
"correct_grad = True will be ignored as method is not saliency."
)
else:
attributions - attributions.mean(1, keepdims=True)

return attributions # N, 4, L


Expand Down