Skip to content

Commit ee5592d

Browse files
authored
Merge pull request #106 from Genentech/correct-gradients
added gradient correction
2 parents 28a9d3b + 3e26c3c commit ee5592d

File tree

2 files changed

+19
-2
lines changed

2 files changed

+19
-2
lines changed

src/grelu/interpret/modisco.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ def run_modisco(
119119
n_shuffles: int = 10,
120120
seed=None,
121121
method: str = "deepshap",
122+
correct_grad: bool = False,
122123
**kwargs,
123124
) -> None:
124125
"""
@@ -139,6 +140,8 @@ def run_modisco(
139140
n_shuffles: Number of times to shuffle the background sequences for deepshap.
140141
seed: Random seed
141142
method: Either "deepshap", "saliency" or "ism".
143+
correct_grad: If True, gradients will be corrected using the method of Majdandzic et al.
144+
(PMID: 37161475). Only used with method='saliency'.
142145
**kwargs: Additional arguments to pass to TF-Modisco.
143146
144147
Raises:
@@ -180,6 +183,7 @@ def run_modisco(
180183
hypothetical=True,
181184
genome=genome,
182185
seed=seed,
186+
correct_grad=correct_grad,
183187
)
184188
attrs = attrs[:, :, start:end]
185189

src/grelu/interpret/score.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ def get_attributions(
129129
prediction_transform: Optional[Callable] = None,
130130
device: Union[str, int] = "cpu",
131131
method: str = "deepshap",
132+
correct_grad: bool = False,
132133
hypothetical: bool = False,
133134
n_shuffles: int = 20,
134135
seed=None,
@@ -144,8 +145,10 @@ def get_attributions(
144145
prediction_transform: A module to transform the model output
145146
devices: Indices of the devices to use for inference
146147
method: One of "deepshap", "saliency", "inputxgradient" or "integratedgradients"
147-
hypothetical: whether to calculate hypothetical importance scores.
148-
Set this to True to obtain input for tf-modisco, False otherwise
148+
correct_grad: If True, gradients will be corrected using the method of Majdandzic et al.
149+
(PMID: 37161475). Only used with method='saliency'.
150+
hypothetical: Only used with method = "deepshap". If true, the function will return
151+
hypothetical importance scores which can be used as input for tf-modisco.
149152
n_shuffles: Number of times to dinucleotide shuffle sequence
150153
seed: Random seed
151154
**kwargs: Additional arguments to pass to tangermeme.deep_lift_shap.deep_lift_shap
@@ -210,6 +213,16 @@ def get_attributions(
210213

211214
# Remove transform
212215
model.reset_transform()
216+
217+
# Correct gradients
218+
if correct_grad:
219+
if method != "saliency":
220+
warnings.warn(
221+
"correct_grad = True will be ignored as method is not saliency."
222+
)
223+
else:
224+
attributions - attributions.mean(1, keepdims=True)
225+
213226
return attributions # N, 4, L
214227

215228

0 commit comments

Comments
 (0)