diff --git a/pytorch_grad_cam/base_cam.py b/pytorch_grad_cam/base_cam.py index 7ee19297..590dc048 100644 --- a/pytorch_grad_cam/base_cam.py +++ b/pytorch_grad_cam/base_cam.py @@ -16,6 +16,8 @@ def __init__(self, reshape_transform: Callable = None, compute_input_gradient: bool = False, uses_gradients: bool = True) -> None: + for params in model.parameters(): + params.requires_grad = True self.model = model.eval() self.target_layers = target_layers self.cuda = use_cuda