Commit 3f6b14d
Support for 3D Conv-Net (#466)
* Modify grad-cam and base-cam to support 3D conv.
* Add image examples for 3D convolutions.
* Modify get_cam_image to increase readability.

Co-authored-by: Jacob Gildenblat <[email protected]>
1 parent f0371ab commit 3f6b14d

File tree

5 files changed: +104 −91 lines

README.md
examples/multiorgan_segmentation.gif
pytorch_grad_cam/base_cam.py
pytorch_grad_cam/grad_cam.py
pytorch_grad_cam/utils/image.py

README.md

Lines changed: 4 additions & 0 deletions
```diff
@@ -57,6 +57,10 @@ The aim is also to serve as a benchmark of algorithms and metrics for research o
 | -----------------|-----------------------|
 | <img src="./examples/both_detection.png" width="256" height="256"> | <img src="./examples/cars_segmentation.png" width="256" height="200"> |
 
+| Semantic Segmentation (3D) |
+| -------------------------- |
+| <img src="./examples/multiorgan_segmentation.gif" width="539">|
+
 ## Explaining similarity to other images / embeddings
 <img src="./examples/embeddings.png">
 
```
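The new README entry implies the pipeline now runs end to end on volumetric inputs. A minimal sketch of that workflow (not part of the commit): torchvision's `r3d_18` video classifier stands in for a 3D model, and the input sizes and target layer are illustrative assumptions.

```python
import torch
from torchvision.models.video import r3d_18
from pytorch_grad_cam import GradCAM

# Stand-in 3D network: any model whose target layer emits
# B x C x D x H x W activations should take the new code path.
model = r3d_18(weights="DEFAULT").eval()
target_layers = [model.layer4[-1]]  # illustrative choice of layer

# Dummy clip: batch x channels x frames x height x width.
input_tensor = torch.randn(1, 3, 16, 112, 112)

with GradCAM(model=model, target_layers=target_layers) as cam:
    grayscale_cam = cam(input_tensor=input_tensor)

# With a 5D input the CAM is rescaled along depth as well, so this
# should print (1, 16, 112, 112): batch x frames x height x width.
print(grayscale_cam.shape)
```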
examples/multiorgan_segmentation.gif

Binary file added: 22.3 MB (the animated 3D segmentation example referenced in the README above)

pytorch_grad_cam/base_cam.py

Lines changed: 77 additions & 85 deletions
```diff
@@ -1,21 +1,25 @@
+from typing import Callable, List, Optional, Tuple
+
 import numpy as np
 import torch
 import ttach as tta
-from typing import Callable, List, Tuple, Optional
+
 from pytorch_grad_cam.activations_and_gradients import ActivationsAndGradients
-from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection
 from pytorch_grad_cam.utils.image import scale_cam_image
 from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
+from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection
 
 
 class BaseCAM:
-    def __init__(self,
-                 model: torch.nn.Module,
-                 target_layers: List[torch.nn.Module],
-                 reshape_transform: Callable = None,
-                 compute_input_gradient: bool = False,
-                 uses_gradients: bool = True,
-                 tta_transforms: Optional[tta.Compose] = None) -> None:
+    def __init__(
+        self,
+        model: torch.nn.Module,
+        target_layers: List[torch.nn.Module],
+        reshape_transform: Callable = None,
+        compute_input_gradient: bool = False,
+        uses_gradients: bool = True,
+        tta_transforms: Optional[tta.Compose] = None,
+    ) -> None:
         self.model = model.eval()
         self.target_layers = target_layers
 
```
```diff
@@ -34,63 +38,64 @@ def __init__(self,
         else:
             self.tta_transforms = tta_transforms
 
-        self.activations_and_grads = ActivationsAndGradients(
-            self.model, target_layers, reshape_transform)
+        self.activations_and_grads = ActivationsAndGradients(self.model, target_layers, reshape_transform)
 
     """ Get a vector of weights for every channel in the target layer.
         Methods that return weights channels,
         will typically need to only implement this function. """
 
-    def get_cam_weights(self,
-                        input_tensor: torch.Tensor,
-                        target_layers: List[torch.nn.Module],
-                        targets: List[torch.nn.Module],
-                        activations: torch.Tensor,
-                        grads: torch.Tensor) -> np.ndarray:
+    def get_cam_weights(
+        self,
+        input_tensor: torch.Tensor,
+        target_layers: List[torch.nn.Module],
+        targets: List[torch.nn.Module],
+        activations: torch.Tensor,
+        grads: torch.Tensor,
+    ) -> np.ndarray:
         raise Exception("Not Implemented")
 
-    def get_cam_image(self,
-                      input_tensor: torch.Tensor,
-                      target_layer: torch.nn.Module,
-                      targets: List[torch.nn.Module],
-                      activations: torch.Tensor,
-                      grads: torch.Tensor,
-                      eigen_smooth: bool = False) -> np.ndarray:
-
-        weights = self.get_cam_weights(input_tensor,
-                                       target_layer,
-                                       targets,
-                                       activations,
-                                       grads)
-        weighted_activations = weights[:, :, None, None] * activations
+    def get_cam_image(
+        self,
+        input_tensor: torch.Tensor,
+        target_layer: torch.nn.Module,
+        targets: List[torch.nn.Module],
+        activations: torch.Tensor,
+        grads: torch.Tensor,
+        eigen_smooth: bool = False,
+    ) -> np.ndarray:
+        weights = self.get_cam_weights(input_tensor, target_layer, targets, activations, grads)
+        # 2D conv
+        if len(activations.shape) == 4:
+            weighted_activations = weights[:, :, None, None] * activations
+        # 3D conv
+        elif len(activations.shape) == 5:
+            weighted_activations = weights[:, :, None, None, None] * activations
+        else:
+            raise ValueError(f"Invalid activation shape. Get {len(activations.shape)}.")
+
         if eigen_smooth:
             cam = get_2d_projection(weighted_activations)
         else:
             cam = weighted_activations.sum(axis=1)
         return cam
 
-    def forward(self,
-                input_tensor: torch.Tensor,
-                targets: List[torch.nn.Module],
-                eigen_smooth: bool = False) -> np.ndarray:
-
+    def forward(
+        self, input_tensor: torch.Tensor, targets: List[torch.nn.Module], eigen_smooth: bool = False
+    ) -> np.ndarray:
         input_tensor = input_tensor.to(self.device)
 
         if self.compute_input_gradient:
-            input_tensor = torch.autograd.Variable(input_tensor,
-                                                   requires_grad=True)
+            input_tensor = torch.autograd.Variable(input_tensor, requires_grad=True)
 
         self.outputs = outputs = self.activations_and_grads(input_tensor)
 
         if targets is None:
             target_categories = np.argmax(outputs.cpu().data.numpy(), axis=-1)
-            targets = [ClassifierOutputTarget(
-                category) for category in target_categories]
+            targets = [ClassifierOutputTarget(category) for category in target_categories]
 
         if self.uses_gradients:
             self.model.zero_grad()
-            loss = sum([target(output)
-                        for target, output in zip(targets, outputs)])
+            loss = sum([target(output) for target, output in zip(targets, outputs)])
             loss.backward(retain_graph=True)
 
         # In most of the saliency attribution papers, the saliency is
```
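The substantive change in this hunk is the rank check before the weighting step: get_cam_weights returns one weight per channel, shape (B, C), which needs one trailing singleton axis per spatial dimension to broadcast against the activations. A standalone NumPy sketch of the new 5D branch (shapes arbitrary):

```python
import numpy as np

# Fake 3D-conv activations: batch x channels x depth x height x width.
activations = np.random.rand(2, 8, 4, 7, 7)
weights = np.random.rand(2, 8)  # one weight per channel, as get_cam_weights returns

# (B, C) -> (B, C, 1, 1, 1), so the weights broadcast over D, H and W.
weighted = weights[:, :, None, None, None] * activations
cam = weighted.sum(axis=1)  # collapse the channel axis, as in get_cam_image

print(cam.shape)  # (2, 4, 7, 7)
```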
```diff
@@ -102,25 +107,24 @@ def forward(self,
         # This gives you more flexibility in case you just want to
         # use all conv layers for example, all Batchnorm layers,
         # or something else.
-        cam_per_layer = self.compute_cam_per_layer(input_tensor,
-                                                   targets,
-                                                   eigen_smooth)
+        cam_per_layer = self.compute_cam_per_layer(input_tensor, targets, eigen_smooth)
         return self.aggregate_multi_layers(cam_per_layer)
 
-    def get_target_width_height(self,
-                                input_tensor: torch.Tensor) -> Tuple[int, int]:
-        width, height = input_tensor.size(-1), input_tensor.size(-2)
-        return width, height
+    def get_target_width_height(self, input_tensor: torch.Tensor) -> Tuple[int, int]:
+        if len(input_tensor.shape) == 4:
+            width, height = input_tensor.size(-1), input_tensor.size(-2)
+            return width, height
+        elif len(input_tensor.shape) == 5:
+            depth, width, height = input_tensor.size(-1), input_tensor.size(-2), input_tensor.size(-3)
+            return depth, width, height
+        else:
+            raise ValueError("Invalid input_tensor shape. Only 2D or 3D images are supported.")
 
     def compute_cam_per_layer(
-            self,
-            input_tensor: torch.Tensor,
-            targets: List[torch.nn.Module],
-            eigen_smooth: bool) -> np.ndarray:
-        activations_list = [a.cpu().data.numpy()
-                            for a in self.activations_and_grads.activations]
-        grads_list = [g.cpu().data.numpy()
-                      for g in self.activations_and_grads.gradients]
+        self, input_tensor: torch.Tensor, targets: List[torch.nn.Module], eigen_smooth: bool
+    ) -> np.ndarray:
+        activations_list = [a.cpu().data.numpy() for a in self.activations_and_grads.activations]
+        grads_list = [g.cpu().data.numpy() for g in self.activations_and_grads.gradients]
         target_size = self.get_target_width_height(input_tensor)
 
         cam_per_target_layer = []
```
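One naming quirk in the new 5D branch is worth flagging: the tuple is built from size(-1), size(-2), size(-3), so for the usual B x C x D x H x W layout it actually holds (W, H, D) even though the variables are called depth, width, height. The ordering is deliberate, since scale_cam_image (further down) reverses target_size before computing zoom factors, but keep it in mind if you consume the tuple directly. A quick illustration:

```python
import torch

# A volumetric batch in the usual B x C x D x H x W layout.
x = torch.randn(1, 1, 16, 64, 32)

# Mirrors the new 5D branch of get_target_width_height.
target_size = (x.size(-1), x.size(-2), x.size(-3))

print(target_size)  # (32, 64, 16) -- that is (W, H, D), not (D, H, W)
```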
```diff
@@ -134,36 +138,26 @@ def compute_cam_per_layer(
             if i < len(grads_list):
                 layer_grads = grads_list[i]
 
-            cam = self.get_cam_image(input_tensor,
-                                     target_layer,
-                                     targets,
-                                     layer_activations,
-                                     layer_grads,
-                                     eigen_smooth)
+            cam = self.get_cam_image(input_tensor, target_layer, targets, layer_activations, layer_grads, eigen_smooth)
             cam = np.maximum(cam, 0)
             scaled = scale_cam_image(cam, target_size)
             cam_per_target_layer.append(scaled[:, None, :])
 
         return cam_per_target_layer
 
-    def aggregate_multi_layers(
-            self,
-            cam_per_target_layer: np.ndarray) -> np.ndarray:
+    def aggregate_multi_layers(self, cam_per_target_layer: np.ndarray) -> np.ndarray:
         cam_per_target_layer = np.concatenate(cam_per_target_layer, axis=1)
         cam_per_target_layer = np.maximum(cam_per_target_layer, 0)
         result = np.mean(cam_per_target_layer, axis=1)
         return scale_cam_image(result)
 
-    def forward_augmentation_smoothing(self,
-                                       input_tensor: torch.Tensor,
-                                       targets: List[torch.nn.Module],
-                                       eigen_smooth: bool = False) -> np.ndarray:
+    def forward_augmentation_smoothing(
+        self, input_tensor: torch.Tensor, targets: List[torch.nn.Module], eigen_smooth: bool = False
+    ) -> np.ndarray:
         cams = []
         for transform in self.tta_transforms:
             augmented_tensor = transform.augment_image(input_tensor)
-            cam = self.forward(augmented_tensor,
-                               targets,
-                               eigen_smooth)
+            cam = self.forward(augmented_tensor, targets, eigen_smooth)
 
             # The ttach library expects a tensor of size BxCxHxW
             cam = cam[:, None, :, :]
```
```diff
@@ -178,19 +172,18 @@ def forward_augmentation_smoothing(self,
         cam = np.mean(np.float32(cams), axis=0)
         return cam
 
-    def __call__(self,
-                 input_tensor: torch.Tensor,
-                 targets: List[torch.nn.Module] = None,
-                 aug_smooth: bool = False,
-                 eigen_smooth: bool = False) -> np.ndarray:
-
+    def __call__(
+        self,
+        input_tensor: torch.Tensor,
+        targets: List[torch.nn.Module] = None,
+        aug_smooth: bool = False,
+        eigen_smooth: bool = False,
+    ) -> np.ndarray:
         # Smooth the CAM result with test time augmentation
         if aug_smooth is True:
-            return self.forward_augmentation_smoothing(
-                input_tensor, targets, eigen_smooth)
+            return self.forward_augmentation_smoothing(input_tensor, targets, eigen_smooth)
 
-        return self.forward(input_tensor,
-                            targets, eigen_smooth)
+        return self.forward(input_tensor, targets, eigen_smooth)
 
     def __del__(self):
         self.activations_and_grads.release()
@@ -202,6 +195,5 @@ def __exit__(self, exc_type, exc_value, exc_tb):
         self.activations_and_grads.release()
         if isinstance(exc_value, IndexError):
             # Handle IndexError here...
-            print(
-                f"An exception occurred in CAM with block: {exc_type}. Message: {exc_value}")
+            print(f"An exception occurred in CAM with block: {exc_type}. Message: {exc_value}")
         return True
```

pytorch_grad_cam/grad_cam.py

Lines changed: 12 additions & 1 deletion
```diff
@@ -1,4 +1,5 @@
 import numpy as np
+
 from pytorch_grad_cam.base_cam import BaseCAM
 
 
```
```diff
@@ -18,4 +19,14 @@ def get_cam_weights(self,
                         target_category,
                         activations,
                         grads):
-        return np.mean(grads, axis=(2, 3))
+        # 2D image
+        if len(grads.shape) == 4:
+            return np.mean(grads, axis=(2, 3))
+
+        # 3D image
+        elif len(grads.shape) == 5:
+            return np.mean(grads, axis=(2, 3, 4))
+
+        else:
+            raise ValueError("Invalid grads shape."
+                             "Shape of grads should be 4 (2D image) or 5 (3D image).")
```

pytorch_grad_cam/utils/image.py

Lines changed: 11 additions & 5 deletions
```diff
@@ -1,12 +1,14 @@
-import matplotlib
-from matplotlib import pyplot as plt
-from matplotlib.lines import Line2D
+import math
+from typing import Dict, List
+
 import cv2
+import matplotlib
 import numpy as np
 import torch
+from matplotlib import pyplot as plt
+from matplotlib.lines import Line2D
+from scipy.ndimage import zoom
 from torchvision.transforms import Compose, Normalize, ToTensor
-from typing import List, Dict
-import math
 
 
 def preprocess_image(
```
```diff
@@ -163,7 +165,11 @@ def scale_cam_image(cam, target_size=None):
         img = img - np.min(img)
         img = img / (1e-7 + np.max(img))
         if target_size is not None:
+            if len(img.shape) > 2:
+                img = zoom(np.float32(img), [(t_s / i_s) for i_s, t_s in zip(img.shape, target_size[::-1])])
+            else:
                 img = cv2.resize(np.float32(img), target_size)
+
         result.append(img)
     result = np.float32(result)
 
```
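cv2.resize only handles 2D images, so any CAM with more than two axes now takes a scipy.ndimage.zoom path with per-axis scale factors; reversing target_size lines it up with the D x H x W axis order of the CAM. A standalone sketch of that resize step (sizes are arbitrary):

```python
import numpy as np
from scipy.ndimage import zoom

cam = np.random.rand(4, 7, 7)  # one volumetric CAM: D x H x W
target_size = (112, 112, 16)   # (W, H, D), as get_target_width_height returns it

# Reversed, target_size becomes (D, H, W) and pairs up with cam.shape.
factors = [t_s / i_s for i_s, t_s in zip(cam.shape, target_size[::-1])]
resized = zoom(np.float32(cam), factors)  # spline-interpolated per-axis resize

print(resized.shape)  # (16, 112, 112)
```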