1
1
import numpy as np
2
+ from PIL import Image
2
3
import torch
3
- from typing import Callable , List , Tuple
4
+ from typing import Callable , List , Tuple , Optional
4
5
from sklearn .decomposition import NMF
5
6
from pytorch_grad_cam .activations_and_gradients import ActivationsAndGradients
6
- from pytorch_grad_cam .utils .image import scale_cam_image
7
+ from pytorch_grad_cam .utils .image import scale_cam_image , create_labels_legend , show_factorization_on_image
7
8
8
9
9
10
def dff (activations : np .ndarray , n_components : int = 5 ):
@@ -22,7 +23,7 @@ def dff(activations: np.ndarray, n_components: int = 5):
22
23
reshaped_activations .shape [0 ], - 1 )
23
24
offset = reshaped_activations .min (axis = - 1 )
24
25
reshaped_activations = reshaped_activations - offset [:, None ]
25
-
26
+
26
27
model = NMF (n_components = n_components , init = 'random' , random_state = 0 )
27
28
W = model .fit_transform (reshaped_activations )
28
29
H = model .components_
@@ -60,11 +61,15 @@ def __call__(self,
60
61
n_components : int = 16 ):
61
62
batch_size , channels , h , w = input_tensor .size ()
62
63
_ = self .activations_and_grads (input_tensor )
63
- activations = self .activations_and_grads .activations [0 ].cpu ().numpy ()
64
+
65
+ with torch .no_grad ():
66
+ activations = self .activations_and_grads .activations [0 ].cpu (
67
+ ).numpy ()
68
+
64
69
concepts , explanations = dff (activations , n_components = n_components )
65
70
66
71
processed_explanations = []
67
-
72
+
68
73
for batch in explanations :
69
74
processed_explanations .append (scale_cam_image (batch , (w , h )))
70
75
@@ -88,3 +93,39 @@ def __exit__(self, exc_type, exc_value, exc_tb):
88
93
print (
89
94
f"An exception occurred in ActivationSummary with block: { exc_type } . Message: { exc_value } " )
90
95
return True
96
+
97
+
98
+ def run_dff_on_image (model : torch .nn .Module ,
99
+ target_layer : torch .nn .Module ,
100
+ classifier : torch .nn .Module ,
101
+ img_pil : Image ,
102
+ img_tensor : torch .Tensor ,
103
+ reshape_transform = Optional [Callable ],
104
+ n_components : int = 5 ,
105
+ top_k : int = 2 ) -> np .ndarray :
106
+ """ Helper function to create a Deep Feature Factorization visualization for a single image.
107
+ TBD: Run this on a batch with several images.
108
+ """
109
+ rgb_img_float = np .array (img_pil ) / 255
110
+ dff = DeepFeatureFactorization (model = model ,
111
+ reshape_transform = reshape_transform ,
112
+ target_layer = target_layer ,
113
+ computation_on_concepts = classifier )
114
+
115
+ concepts , batch_explanations , concept_outputs = dff (
116
+ img_tensor [None , :], n_components )
117
+
118
+ concept_outputs = torch .softmax (
119
+ torch .from_numpy (concept_outputs ),
120
+ axis = - 1 ).numpy ()
121
+ concept_label_strings = create_labels_legend (concept_outputs ,
122
+ labels = model .config .id2label ,
123
+ top_k = top_k )
124
+ visualization = show_factorization_on_image (
125
+ rgb_img_float ,
126
+ batch_explanations [0 ],
127
+ image_weight = 0.3 ,
128
+ concept_labels = concept_label_strings )
129
+
130
+ result = np .hstack ((np .array (img_pil ), visualization ))
131
+ return result
0 commit comments