@@ -129,6 +129,7 @@ def get_attributions(
129129 prediction_transform : Optional [Callable ] = None ,
130130 device : Union [str , int ] = "cpu" ,
131131 method : str = "deepshap" ,
132+ correct_grad : bool = False ,
132133 hypothetical : bool = False ,
133134 n_shuffles : int = 20 ,
134135 seed = None ,
@@ -144,8 +145,10 @@ def get_attributions(
144145 prediction_transform: A module to transform the model output
145146 devices: Indices of the devices to use for inference
146147 method: One of "deepshap", "saliency", "inputxgradient" or "integratedgradients"
147- hypothetical: whether to calculate hypothetical importance scores.
148- Set this to True to obtain input for tf-modisco, False otherwise
148+ correct_grad: If True, gradients will be corrected using the method of Majdandzic et al.
149+ (PMID: 37161475). Only used with method='saliency'.
150+ hypothetical: Only used with method = "deepshap". If true, the function will return
151+ hypothetical importance scores which can be used as input for tf-modisco.
149152 n_shuffles: Number of times to dinucleotide shuffle sequence
150153 seed: Random seed
151154 **kwargs: Additional arguments to pass to tangermeme.deep_lift_shap.deep_lift_shap
@@ -210,6 +213,16 @@ def get_attributions(
210213
211214 # Remove transform
212215 model .reset_transform ()
216+
217+ # Correct gradients
218+ if correct_grad :
219+ if method != "saliency" :
220+ warnings .warn (
221+ "correct_grad = True will be ignored as method is not saliency."
222+ )
223+ else :
224+ attributions - attributions .mean (1 , keepdims = True )
225+
213226 return attributions # N, 4, L
214227
215228
0 commit comments