Skip to content

Commit 3f6c1fb

Browse files
authored
added support for saving output in output folder, (#439)
* Update cam.py added support for output_dir argument * Update README.md updated redme & added support for storing output in customoutputdir. python cam.py --image-path input.jpg --method gradcam --output-dir custom_output * Create cam_custom.py add your own custom_model * Delete cam_custom.py
1 parent 58a565a commit 3f6c1fb

File tree

2 files changed

+50
-47
lines changed

2 files changed

+50
-47
lines changed

README.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -238,10 +238,11 @@ two smoothing methods are supported:
238238

239239
# Running the example script:
240240

241-
Usage: `python cam.py --image-path <path_to_image> --method <method>`
241+
Usage: `python cam.py --image-path <path_to_image> --method <method> --output-dir <output_dir_path> `
242+
242243

243244
To use with CUDA:
244-
`python cam.py --image-path <path_to_image> --use-cuda`
245+
`python cam.py --image-path <path_to_image> --use-cuda --output-dir <output_dir_path> `
245246

246247
----------
247248

cam.py

Lines changed: 47 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,18 @@
11
import argparse
2+
import os
23
import cv2
34
import numpy as np
45
import torch
56
from torchvision import models
6-
from pytorch_grad_cam import GradCAM, \
7-
HiResCAM, \
8-
ScoreCAM, \
9-
GradCAMPlusPlus, \
10-
AblationCAM, \
11-
XGradCAM, \
12-
EigenCAM, \
13-
EigenGradCAM, \
14-
LayerCAM, \
15-
FullGrad, \
16-
GradCAMElementWise
17-
18-
7+
from pytorch_grad_cam import (
8+
GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus,
9+
AblationCAM, XGradCAM, EigenCAM, EigenGradCAM,
10+
LayerCAM, FullGrad, GradCAMElementWise
11+
)
1912
from pytorch_grad_cam import GuidedBackpropReLUModel
20-
from pytorch_grad_cam.utils.image import show_cam_on_image, \
21-
deprocess_image, \
22-
preprocess_image
13+
from pytorch_grad_cam.utils.image import (
14+
show_cam_on_image, deprocess_image, preprocess_image
15+
)
2316
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
2417

2518

@@ -32,21 +25,24 @@ def get_args():
3225
type=str,
3326
default='./examples/both.png',
3427
help='Input image path')
35-
parser.add_argument('--aug_smooth', action='store_true',
28+
parser.add_argument('--aug-smooth', action='store_true',
3629
help='Apply test time augmentation to smooth the CAM')
3730
parser.add_argument(
38-
'--eigen_smooth',
31+
'--eigen-smooth',
3932
action='store_true',
40-
help='Reduce noise by taking the first principle componenet'
33+
help='Reduce noise by taking the first principle component'
4134
'of cam_weights*activations')
4235
parser.add_argument('--method', type=str, default='gradcam',
43-
choices=['gradcam', 'hirescam', 'gradcam++',
44-
'scorecam', 'xgradcam',
45-
'ablationcam', 'eigencam',
46-
'eigengradcam', 'layercam', 'fullgrad'],
47-
help='Can be gradcam/gradcam++/scorecam/xgradcam'
48-
'/ablationcam/eigencam/eigengradcam/layercam')
49-
36+
choices=[
37+
'gradcam', 'hirescam', 'gradcam++',
38+
'scorecam', 'xgradcam', 'ablationcam',
39+
'eigencam', 'eigengradcam', 'layercam',
40+
'fullgrad', 'gradcamelementwise'
41+
],
42+
help='CAM method')
43+
44+
parser.add_argument('--output-dir', type=str, default='output',
45+
help='Output directory to save the images')
5046
args = parser.parse_args()
5147
args.use_cuda = args.use_cuda and torch.cuda.is_available()
5248
if args.use_cuda:
@@ -59,25 +55,26 @@ def get_args():
5955

6056
if __name__ == '__main__':
6157
""" python cam.py -image-path <path_to_image>
62-
Example usage of loading an image, and computing:
58+
Example usage of loading an image and computing:
6359
1. CAM
6460
2. Guided Back Propagation
6561
3. Combining both
6662
"""
6763

6864
args = get_args()
69-
methods = \
70-
{"gradcam": GradCAM,
71-
"hirescam": HiResCAM,
72-
"scorecam": ScoreCAM,
73-
"gradcam++": GradCAMPlusPlus,
74-
"ablationcam": AblationCAM,
75-
"xgradcam": XGradCAM,
76-
"eigencam": EigenCAM,
77-
"eigengradcam": EigenGradCAM,
78-
"layercam": LayerCAM,
79-
"fullgrad": FullGrad,
80-
"gradcamelementwise": GradCAMElementWise}
65+
methods = {
66+
"gradcam": GradCAM,
67+
"hirescam": HiResCAM,
68+
"scorecam": ScoreCAM,
69+
"gradcam++": GradCAMPlusPlus,
70+
"ablationcam": AblationCAM,
71+
"xgradcam": XGradCAM,
72+
"eigencam": EigenCAM,
73+
"eigengradcam": EigenGradCAM,
74+
"layercam": LayerCAM,
75+
"fullgrad": FullGrad,
76+
"gradcamelementwise": GradCAMElementWise
77+
}
8178

8279
model = models.resnet50(pretrained=True)
8380

@@ -93,6 +90,7 @@ def get_args():
9390
# You can also try selecting all layers of a certain type, with e.g:
9491
# from pytorch_grad_cam.utils.find_layers import find_layer_types_recursive
9592
# find_layer_types_recursive(model, [torch.nn.ReLU])
93+
9694
target_layers = [model.layer4]
9795

9896
rgb_img = cv2.imread(args.image_path, 1)[:, :, ::-1]
@@ -115,6 +113,7 @@ def get_args():
115113
target_layers=target_layers,
116114
use_cuda=args.use_cuda) as cam:
117115

116+
118117
# AblationCAM and ScoreCAM have batched implementations.
119118
# You can override the internal batch size for faster computation.
120119
cam.batch_size = 32
@@ -123,12 +122,9 @@ def get_args():
123122
aug_smooth=args.aug_smooth,
124123
eigen_smooth=args.eigen_smooth)
125124

126-
# Here grayscale_cam has only one image in the batch
127125
grayscale_cam = grayscale_cam[0, :]
128126

129127
cam_image = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)
130-
131-
# cam_image is RGB encoded whereas "cv2.imwrite" requires BGR encoding.
132128
cam_image = cv2.cvtColor(cam_image, cv2.COLOR_RGB2BGR)
133129

134130
gb_model = GuidedBackpropReLUModel(model=model, use_cuda=args.use_cuda)
@@ -138,6 +134,12 @@ def get_args():
138134
cam_gb = deprocess_image(cam_mask * gb)
139135
gb = deprocess_image(gb)
140136

141-
cv2.imwrite(f'{args.method}_cam.jpg', cam_image)
142-
cv2.imwrite(f'{args.method}_gb.jpg', gb)
143-
cv2.imwrite(f'{args.method}_cam_gb.jpg', cam_gb)
137+
os.makedirs(args.output_dir, exist_ok=True)
138+
139+
cam_output_path = os.path.join(args.output_dir, f'{args.method}_cam.jpg')
140+
gb_output_path = os.path.join(args.output_dir, f'{args.method}_gb.jpg')
141+
cam_gb_output_path = os.path.join(args.output_dir, f'{args.method}_cam_gb.jpg')
142+
143+
cv2.imwrite(cam_output_path, cam_image)
144+
cv2.imwrite(gb_output_path, gb)
145+
cv2.imwrite(cam_gb_output_path, cam_gb)

0 commit comments

Comments
 (0)