Skip to content

Commit 00711a2

Browse files
soberhoferSamuel Oberhofer
andauthored
Infer the device from the model parameters
Infer the device from the model parameters --------- Co-authored-by: Samuel Oberhofer <[email protected]>
1 parent a797af2 commit 00711a2

20 files changed

+39
-55
lines changed

cam.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import numpy as np
55
import torch
66
from torchvision import models
7+
from torchvision.models import ResNet50_Weights
78
from pytorch_grad_cam import (
89
GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus,
910
AblationCAM, XGradCAM, EigenCAM, EigenGradCAM,
@@ -18,8 +19,8 @@
1819

1920
def get_args():
2021
parser = argparse.ArgumentParser()
21-
parser.add_argument('--use-cuda', action='store_true', default=False,
22-
help='Use NVIDIA GPU acceleration')
22+
parser.add_argument('--device', type=str, default=None,
23+
help='Torch device to use')
2324
parser.add_argument(
2425
'--image-path',
2526
type=str,
@@ -44,9 +45,9 @@ def get_args():
4445
parser.add_argument('--output-dir', type=str, default='output',
4546
help='Output directory to save the images')
4647
args = parser.parse_args()
47-
args.use_cuda = args.use_cuda and torch.cuda.is_available()
48-
if args.use_cuda:
49-
print('Using GPU for acceleration')
48+
49+
if args.device:
50+
print(f'Using device "{args.device}" for acceleration')
5051
else:
5152
print('Using CPU for computation')
5253

@@ -76,7 +77,7 @@ def get_args():
7677
"gradcamelementwise": GradCAMElementWise
7778
}
7879

79-
model = models.resnet50(pretrained=True)
80+
model = models.resnet50(weights=ResNet50_Weights.DEFAULT).to(args.device).eval()
8081

8182
# Choose the target layer you want to compute the visualization for.
8283
# Usually this will be the last convolutional layer in the model.
@@ -97,7 +98,7 @@ def get_args():
9798
rgb_img = np.float32(rgb_img) / 255
9899
input_tensor = preprocess_image(rgb_img,
99100
mean=[0.485, 0.456, 0.406],
100-
std=[0.229, 0.224, 0.225])
101+
std=[0.229, 0.224, 0.225]).to(args.device)
101102

102103
# We have to specify the target we want to generate
103104
# the Class Activation Maps for.
@@ -111,7 +112,7 @@ def get_args():
111112
cam_algorithm = methods[args.method]
112113
with cam_algorithm(model=model,
113114
target_layers=target_layers,
114-
use_cuda=args.use_cuda) as cam:
115+
device=args.device) as cam:
115116

116117

117118
# AblationCAM and ScoreCAM have batched implementations.
@@ -127,7 +128,7 @@ def get_args():
127128
cam_image = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)
128129
cam_image = cv2.cvtColor(cam_image, cv2.COLOR_RGB2BGR)
129130

130-
gb_model = GuidedBackpropReLUModel(model=model, use_cuda=args.use_cuda)
131+
gb_model = GuidedBackpropReLUModel(model=model, device=args.device)
131132
gb = gb_model(input_tensor, target_category=None)
132133

133134
cam_mask = cv2.merge([grayscale_cam, grayscale_cam, grayscale_cam])

pytorch_grad_cam/ablation_cam.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,13 @@ class AblationCAM(BaseCAM):
2828
def __init__(self,
2929
model: torch.nn.Module,
3030
target_layers: List[torch.nn.Module],
31-
use_cuda: bool = False,
3231
reshape_transform: Callable = None,
3332
ablation_layer: torch.nn.Module = AblationLayer(),
3433
batch_size: int = 32,
3534
ratio_channels_to_ablate: float = 1.0) -> None:
3635

3736
super(AblationCAM, self).__init__(model,
3837
target_layers,
39-
use_cuda,
4038
reshape_transform,
4139
uses_gradients=False)
4240
self.batch_size = batch_size

pytorch_grad_cam/ablation_cam_multilayer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,9 @@ def replace_layer_recursive(model, old_layer, new_layer):
5757

5858

5959
class AblationCAM(BaseCAM):
60-
def __init__(self, model, target_layers, use_cuda=False,
60+
def __init__(self, model, target_layers,
6161
reshape_transform=None):
62-
super(AblationCAM, self).__init__(model, target_layers, use_cuda,
62+
super(AblationCAM, self).__init__(model, target_layers,
6363
reshape_transform)
6464

6565
if len(target_layers) > 1:

pytorch_grad_cam/base_cam.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,14 @@ class BaseCAM:
1212
def __init__(self,
1313
model: torch.nn.Module,
1414
target_layers: List[torch.nn.Module],
15-
use_cuda: bool = False,
1615
reshape_transform: Callable = None,
1716
compute_input_gradient: bool = False,
1817
uses_gradients: bool = True,
1918
tta_transforms: Optional[tta.Compose] = None) -> None:
2019
self.model = model.eval()
2120
self.target_layers = target_layers
22-
self.cuda = use_cuda
23-
if self.cuda:
24-
self.model = model.cuda()
21+
self.device = next(self.model.parameters()).device
22+
2523
self.reshape_transform = reshape_transform
2624
self.compute_input_gradient = compute_input_gradient
2725
self.uses_gradients = uses_gradients
@@ -75,8 +73,7 @@ def forward(self,
7573
targets: List[torch.nn.Module],
7674
eigen_smooth: bool = False) -> np.ndarray:
7775

78-
if self.cuda:
79-
input_tensor = input_tensor.cuda()
76+
input_tensor = input_tensor.to(self.device)
8077

8178
if self.compute_input_gradient:
8279
input_tensor = torch.autograd.Variable(input_tensor,

pytorch_grad_cam/eigen_cam.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,10 @@
55

66

77
class EigenCAM(BaseCAM):
8-
def __init__(self, model, target_layers, use_cuda=False,
8+
def __init__(self, model, target_layers,
99
reshape_transform=None):
1010
super(EigenCAM, self).__init__(model,
1111
target_layers,
12-
use_cuda,
1312
reshape_transform,
1413
uses_gradients=False)
1514

pytorch_grad_cam/eigen_grad_cam.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66

77

88
class EigenGradCAM(BaseCAM):
9-
def __init__(self, model, target_layers, use_cuda=False,
9+
def __init__(self, model, target_layers,
1010
reshape_transform=None):
11-
super(EigenGradCAM, self).__init__(model, target_layers, use_cuda,
11+
super(EigenGradCAM, self).__init__(model, target_layers,
1212
reshape_transform)
1313

1414
def get_cam_image(self,

pytorch_grad_cam/fullgrad_cam.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010

1111
class FullGrad(BaseCAM):
12-
def __init__(self, model, target_layers, use_cuda=False,
12+
def __init__(self, model, target_layers,
1313
reshape_transform=None):
1414
if len(target_layers) > 0:
1515
print(
@@ -27,7 +27,6 @@ def layer_with_2D_bias(layer):
2727
self).__init__(
2828
model,
2929
target_layers,
30-
use_cuda,
3130
reshape_transform,
3231
compute_input_gradient=True)
3332
self.bias_data = [self.get_bias_data(

pytorch_grad_cam/grad_cam.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,13 @@
33

44

55
class GradCAM(BaseCAM):
6-
def __init__(self, model, target_layers, use_cuda=False,
6+
def __init__(self, model, target_layers,
77
reshape_transform=None):
88
super(
99
GradCAM,
1010
self).__init__(
1111
model,
1212
target_layers,
13-
use_cuda,
1413
reshape_transform)
1514

1615
def get_cam_weights(self,

pytorch_grad_cam/grad_cam_elementwise.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,13 @@
44

55

66
class GradCAMElementWise(BaseCAM):
7-
def __init__(self, model, target_layers, use_cuda=False,
7+
def __init__(self, model, target_layers,
88
reshape_transform=None):
99
super(
1010
GradCAMElementWise,
1111
self).__init__(
1212
model,
1313
target_layers,
14-
use_cuda,
1514
reshape_transform)
1615

1716
def get_cam_image(self,

pytorch_grad_cam/grad_cam_plusplus.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55

66

77
class GradCAMPlusPlus(BaseCAM):
8-
def __init__(self, model, target_layers, use_cuda=False,
8+
def __init__(self, model, target_layers,
99
reshape_transform=None):
10-
super(GradCAMPlusPlus, self).__init__(model, target_layers, use_cuda,
10+
super(GradCAMPlusPlus, self).__init__(model, target_layers,
1111
reshape_transform)
1212

1313
def get_cam_weights(self,

0 commit comments

Comments
 (0)