Skip to content

Commit 9921b42

Browse files
committed
XAI for HuggingFace vision models tutorial
1 parent 60cf39d commit 9921b42

File tree

11 files changed

+1079
-22
lines changed

11 files changed

+1079
-22
lines changed

cam.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
LayerCAM, \
1515
FullGrad, \
1616
GradCAMElementWise
17-
17+
1818

1919
from pytorch_grad_cam import GuidedBackpropReLUModel
2020
from pytorch_grad_cam.utils.image import show_cam_on_image, \
@@ -68,7 +68,7 @@ def get_args():
6868
args = get_args()
6969
methods = \
7070
{"gradcam": GradCAM,
71-
"hirescam":HiResCAM,
71+
"hirescam": HiResCAM,
7272
"scorecam": ScoreCAM,
7373
"gradcam++": GradCAMPlusPlus,
7474
"ablationcam": AblationCAM,
@@ -101,7 +101,6 @@ def get_args():
101101
mean=[0.485, 0.456, 0.406],
102102
std=[0.229, 0.224, 0.225])
103103

104-
105104
# We have to specify the target we want to generate
106105
# the Class Activation Maps for.
107106
# If targets is None, the highest scoring category (for every member in the batch) will be used.

pytorch_grad_cam/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
from pytorch_grad_cam.fullgrad_cam import FullGrad
1414
from pytorch_grad_cam.guided_backprop import GuidedBackpropReLUModel
1515
from pytorch_grad_cam.activations_and_gradients import ActivationsAndGradients
16-
from pytorch_grad_cam.feature_factorization.deep_feature_factorization import DeepFeatureFactorization
16+
from pytorch_grad_cam.feature_factorization.deep_feature_factorization import DeepFeatureFactorization, run_dff_on_image
1717
import pytorch_grad_cam.utils.model_targets
1818
import pytorch_grad_cam.utils.reshape_transforms
1919
import pytorch_grad_cam.metrics.cam_mult_image
20-
import pytorch_grad_cam.metrics.road
20+
import pytorch_grad_cam.metrics.road

pytorch_grad_cam/feature_factorization/deep_feature_factorization.py

Lines changed: 46 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import numpy as np
2+
from PIL import Image
23
import torch
3-
from typing import Callable, List, Tuple
4+
from typing import Callable, List, Tuple, Optional
45
from sklearn.decomposition import NMF
56
from pytorch_grad_cam.activations_and_gradients import ActivationsAndGradients
6-
from pytorch_grad_cam.utils.image import scale_cam_image
7+
from pytorch_grad_cam.utils.image import scale_cam_image, create_labels_legend, show_factorization_on_image
78

89

910
def dff(activations: np.ndarray, n_components: int = 5):
@@ -22,7 +23,7 @@ def dff(activations: np.ndarray, n_components: int = 5):
2223
reshaped_activations.shape[0], -1)
2324
offset = reshaped_activations.min(axis=-1)
2425
reshaped_activations = reshaped_activations - offset[:, None]
25-
26+
2627
model = NMF(n_components=n_components, init='random', random_state=0)
2728
W = model.fit_transform(reshaped_activations)
2829
H = model.components_
@@ -60,11 +61,15 @@ def __call__(self,
6061
n_components: int = 16):
6162
batch_size, channels, h, w = input_tensor.size()
6263
_ = self.activations_and_grads(input_tensor)
63-
activations = self.activations_and_grads.activations[0].cpu().numpy()
64+
65+
with torch.no_grad():
66+
activations = self.activations_and_grads.activations[0].cpu(
67+
).numpy()
68+
6469
concepts, explanations = dff(activations, n_components=n_components)
6570

6671
processed_explanations = []
67-
72+
6873
for batch in explanations:
6974
processed_explanations.append(scale_cam_image(batch, (w, h)))
7075

@@ -88,3 +93,39 @@ def __exit__(self, exc_type, exc_value, exc_tb):
8893
print(
8994
f"An exception occurred in ActivationSummary with block: {exc_type}. Message: {exc_value}")
9095
return True
96+
97+
98+
def run_dff_on_image(model: torch.nn.Module,
99+
target_layer: torch.nn.Module,
100+
classifier: torch.nn.Module,
101+
img_pil: Image,
102+
img_tensor: torch.Tensor,
103+
reshape_transform=Optional[Callable],
104+
n_components: int = 5,
105+
top_k: int = 2) -> np.ndarray:
106+
""" Helper function to create a Deep Feature Factorization visualization for a single image.
107+
TBD: Run this on a batch with several images.
108+
"""
109+
rgb_img_float = np.array(img_pil) / 255
110+
dff = DeepFeatureFactorization(model=model,
111+
reshape_transform=reshape_transform,
112+
target_layer=target_layer,
113+
computation_on_concepts=classifier)
114+
115+
concepts, batch_explanations, concept_outputs = dff(
116+
img_tensor[None, :], n_components)
117+
118+
concept_outputs = torch.softmax(
119+
torch.from_numpy(concept_outputs),
120+
axis=-1).numpy()
121+
concept_label_strings = create_labels_legend(concept_outputs,
122+
labels=model.config.id2label,
123+
top_k=top_k)
124+
visualization = show_factorization_on_image(
125+
rgb_img_float,
126+
batch_explanations[0],
127+
image_weight=0.3,
128+
concept_labels=concept_label_strings)
129+
130+
result = np.hstack((np.array(img_pil), visualization))
131+
return result

pytorch_grad_cam/utils/image.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
import numpy as np
66
import torch
77
from torchvision.transforms import Compose, Normalize, ToTensor
8-
from typing import List
8+
from typing import List, Dict
9+
import math
910

1011

1112
def preprocess_image(
@@ -63,6 +64,22 @@ def show_cam_on_image(img: np.ndarray,
6364
return np.uint8(255 * cam)
6465

6566

67+
def create_labels_legend(concept_scores: np.ndarray,
68+
labels: Dict[int, str],
69+
top_k=2):
70+
concept_categories = np.argsort(concept_scores, axis=1)[:, ::-1][:, :top_k]
71+
concept_labels_topk = []
72+
for concept_index in range(concept_categories.shape[0]):
73+
categories = concept_categories[concept_index, :]
74+
concept_labels = []
75+
for category in categories:
76+
score = concept_scores[concept_index, category]
77+
label = f"{','.join(labels[category].split(',')[:3])}:{score:.2f}"
78+
concept_labels.append(label)
79+
concept_labels_topk.append("\n".join(concept_labels))
80+
return concept_labels_topk
81+
82+
6683
def show_factorization_on_image(img: np.ndarray,
6784
explanations: np.ndarray,
6885
colors: List[np.ndarray] = None,
@@ -118,7 +135,8 @@ def show_factorization_on_image(img: np.ndarray,
118135
if concept_labels is not None:
119136
px = 1 / plt.rcParams['figure.dpi'] # pixel in inches
120137
fig = plt.figure(figsize=(result.shape[1] * px, result.shape[0] * px))
121-
plt.rcParams['legend.fontsize'] = 15 * result.shape[0] / 256
138+
plt.rcParams['legend.fontsize'] = int(
139+
14 * result.shape[0] / 256 / max(1, n_components / 6))
122140
lw = 5 * result.shape[0] / 256
123141
lines = [Line2D([0], [0], color=colors[i], lw=lw)
124142
for i in range(n_components)]

setup.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
setuptools.setup(
1010
name='grad-cam',
11-
version='1.4.5',
11+
version='1.4.6',
1212
author='Jacob Gildenblat',
1313
author_email='[email protected]',
1414
description='Many Class Activation Map methods implemented in Pytorch for classification, segmentation, object detection and more',
@@ -23,7 +23,7 @@
2323
'License :: OSI Approved :: MIT License',
2424
'Operating System :: OS Independent',
2525
],
26-
packages=setuptools.find_packages(exclude=["*tutorials*"]),
26+
packages=setuptools.find_packages(
27+
exclude=["*tutorials*"]),
2728
python_requires='>=3.6',
28-
install_requires=requirements
29-
)
29+
install_requires=requirements)

tests/test_context_release.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,8 @@ def test_memory_usage_in_loop(numpy_image, batch_size, width, height,
5252
target_layers = []
5353
for layer in target_layer_names:
5454
target_layers.append(eval(f"model.{layer}"))
55-
targets = [ClassifierOutputTarget(target_category) for _ in range(batch_size)]
55+
targets = [ClassifierOutputTarget(target_category)
56+
for _ in range(batch_size)]
5657
initial_memory = 0
5758
for i in range(100):
5859
with cam_method(model=model,

tests/test_one_channel.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import pytest
2+
import torchvision
3+
import torch
4+
import cv2
5+
import numpy as np
6+
from pytorch_grad_cam import GradCAM, \
7+
ScoreCAM, \
8+
GradCAMPlusPlus, \
9+
AblationCAM, \
10+
XGradCAM, \
11+
EigenCAM, \
12+
EigenGradCAM, \
13+
LayerCAM, \
14+
FullGrad
15+
from pytorch_grad_cam.utils.image import show_cam_on_image, \
16+
preprocess_image
17+
18+
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
19+
20+
torch.manual_seed(0)
21+
22+
23+
@pytest.fixture
24+
def numpy_image():
25+
return cv2.imread("examples/both.png")
26+
27+
28+
@pytest.mark.parametrize("cam_method",
29+
[GradCAM])
30+
def test_memory_usage_in_loop(numpy_image, cam_method):
31+
model = torchvision.models.resnet18(pretrained=False)
32+
model.conv1 = torch.nn.Conv2d(
33+
1, 64, kernel_size=(
34+
7, 7), stride=(
35+
2, 2), padding=(
36+
3, 3), bias=False)
37+
target_layers = [model.layer4]
38+
gray_img = numpy_image[:, :, 0]
39+
input_tensor = torch.from_numpy(
40+
np.float32(gray_img)).unsqueeze(0).unsqueeze(0)
41+
input_tensor = input_tensor.repeat(16, 1, 1, 1)
42+
print("input_tensor", input_tensor.shape)
43+
targets = None
44+
with cam_method(model=model,
45+
target_layers=target_layers,
46+
use_cuda=False) as cam:
47+
grayscale_cam = cam(input_tensor=input_tensor,
48+
targets=targets)
49+
print(grayscale_cam.shape)

tests/test_run_all_models.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,10 @@ def test_all_cam_models_can_run(numpy_image, batch_size, width, height,
6868
use_cuda=False)
6969
cam.batch_size = 4
7070
if target_category is None:
71-
targets = None
71+
targets = None
7272
else:
73-
targets = [ClassifierOutputTarget(target_category) for _ in range(batch_size)]
73+
targets = [ClassifierOutputTarget(target_category)
74+
for _ in range(batch_size)]
7475

7576
grayscale_cam = cam(input_tensor=input_tensor,
7677
targets=targets,

0 commit comments

Comments
 (0)