Skip to content

Commit 51ae192

Browse files
committed
Documentation for device change
1 parent 00711a2 commit 51ae192

File tree

5 files changed

+41
-42
lines changed

5 files changed

+41
-42
lines changed

README.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -121,11 +121,11 @@ input_tensor = # Create an input tensor image for your model..
121121
# Note: input_tensor can be a batch tensor with several images!
122122

123123
# Construct the CAM object once, and then re-use it on many images:
124-
cam = GradCAM(model=model, target_layers=target_layers, use_cuda=args.use_cuda)
124+
cam = GradCAM(model=model, target_layers=target_layers)
125125

126126
# You can also use it within a with statement, to make sure it is freed,
127127
# In case you need to re-create it inside an outer loop:
128-
# with GradCAM(model=model, target_layers=target_layers, use_cuda=args.use_cuda) as cam:
128+
# with GradCAM(model=model, target_layers=target_layers) as cam:
129129
# ...
130130

131131
# We have to specify the target we want to generate
@@ -244,8 +244,8 @@ two smoothing methods are supported:
244244
Usage: `python cam.py --image-path <path_to_image> --method <method> --output-dir <output_dir_path> `
245245

246246

247-
To use with CUDA:
248-
`python cam.py --image-path <path_to_image> --use-cuda --output-dir <output_dir_path> `
247+
To use with a specific device, like cpu, cuda, cuda:0 or mps:
248+
`python cam.py --image-path <path_to_image> --device cuda --output-dir <output_dir_path> `
249249

250250
----------
251251

cam.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import numpy as np
55
import torch
66
from torchvision import models
7-
from torchvision.models import ResNet50_Weights
87
from pytorch_grad_cam import (
98
GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus,
109
AblationCAM, XGradCAM, EigenCAM, EigenGradCAM,
@@ -19,7 +18,7 @@
1918

2019
def get_args():
2120
parser = argparse.ArgumentParser()
22-
parser.add_argument('--device', type=str, default=None,
21+
parser.add_argument('--device', type=str, default='cpu',
2322
help='Torch device to use')
2423
parser.add_argument(
2524
'--image-path',
@@ -77,7 +76,7 @@ def get_args():
7776
"gradcamelementwise": GradCAMElementWise
7877
}
7978

80-
model = models.resnet50(weights=ResNet50_Weights.DEFAULT).to(args.device).eval()
79+
model = models.resnet50(pretrained=True).to(torch.device(args.device)).eval()
8180

8281
# Choose the target layer you want to compute the visualization for.
8382
# Usually this will be the last convolutional layer in the model.
@@ -104,16 +103,15 @@ def get_args():
104103
# the Class Activation Maps for.
105104
# If targets is None, the highest scoring category (for every member in the batch) will be used.
106105
# You can target specific categories by
107-
# targets = [e.g ClassifierOutputTarget(281)]
106+
# targets = [ClassifierOutputTarget(281)]
107+
# targets = [ClassifierOutputTarget(281)]
108108
targets = None
109109

110110
# Using the with statement ensures the context is freed, and you can
111111
# recreate different CAM objects in a loop.
112112
cam_algorithm = methods[args.method]
113113
with cam_algorithm(model=model,
114-
target_layers=target_layers,
115-
device=args.device) as cam:
116-
114+
target_layers=target_layers) as cam:
117115

118116
# AblationCAM and ScoreCAM have batched implementations.
119117
# You can override the internal batch size for faster computation.

pytorch_grad_cam/base_cam.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,9 @@ def __init__(self,
1818
tta_transforms: Optional[tta.Compose] = None) -> None:
1919
self.model = model.eval()
2020
self.target_layers = target_layers
21-
self.device = next(self.model.parameters()).device
2221

22+
# Use the same device as the model.
23+
self.device = next(self.model.parameters()).device
2324
self.reshape_transform = reshape_transform
2425
self.compute_input_gradient = compute_input_gradient
2526
self.uses_gradients = uses_gradients

pytorch_grad_cam/score_cam.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def get_cam_weights(self,
2424
upsample = torch.nn.UpsamplingBilinear2d(
2525
size=input_tensor.shape[-2:])
2626
activation_tensor = torch.from_numpy(activations)
27-
activation_tensor = activation_tensor.to(next(self.model.parameters()).device)
27+
activation_tensor = activation_tensor.to(self.device)
2828

2929
upsampled = upsample(activation_tensor)
3030

setup.py

Lines changed: 29 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,29 @@
1-
import setuptools
2-
3-
with open('README.md', mode='r', encoding='utf-8') as fh:
4-
long_description = fh.read()
5-
6-
with open("requirements.txt", "r") as f:
7-
requirements = f.readlines()
8-
9-
setuptools.setup(
10-
name='grad-cam',
11-
version='1.4.8',
12-
author='Jacob Gildenblat',
13-
author_email='[email protected]',
14-
description='Many Class Activation Map methods implemented in Pytorch for classification, segmentation, object detection and more',
15-
long_description=long_description,
16-
long_description_content_type='text/markdown',
17-
url='https://github.com/jacobgil/pytorch-grad-cam',
18-
project_urls={
19-
'Bug Tracker': 'https://github.com/jacobgil/pytorch-grad-cam/issues',
20-
},
21-
classifiers=[
22-
'Programming Language :: Python :: 3',
23-
'License :: OSI Approved :: MIT License',
24-
'Operating System :: OS Independent',
25-
],
26-
packages=setuptools.find_packages(
27-
exclude=["*tutorials*"]),
28-
python_requires='>=3.6',
29-
install_requires=requirements)
1+
import setuptools
2+
3+
with open('README.md', mode='r', encoding='utf-8') as fh:
4+
long_description = fh.read()
5+
6+
with open("requirements.txt", "r") as f:
7+
requirements = f.readlines()
8+
9+
setuptools.setup(
10+
name='grad-cam',
11+
version='1.5.0',
12+
author='Jacob Gildenblat',
13+
author_email='[email protected]',
14+
description='Many Class Activation Map methods implemented in Pytorch for classification, segmentation, object detection and more',
15+
long_description=long_description,
16+
long_description_content_type='text/markdown',
17+
url='https://github.com/jacobgil/pytorch-grad-cam',
18+
project_urls={
19+
'Bug Tracker': 'https://github.com/jacobgil/pytorch-grad-cam/issues',
20+
},
21+
classifiers=[
22+
'Programming Language :: Python :: 3',
23+
'License :: OSI Approved :: MIT License',
24+
'Operating System :: OS Independent',
25+
],
26+
packages=setuptools.find_packages(
27+
exclude=["*tutorials*"]),
28+
python_requires='>=3.8',
29+
install_requires=requirements)

0 commit comments

Comments
 (0)