diff --git a/usage_examples/vit_example.py b/usage_examples/vit_example.py index 5cf51d69..eb89f94f 100644 --- a/usage_examples/vit_example.py +++ b/usage_examples/vit_example.py @@ -98,13 +98,11 @@ def reshape_transform(tensor, height=14, width=14): if args.method == "ablationcam": cam = methods[args.method](model=model, target_layers=target_layers, - use_cuda=args.use_cuda, reshape_transform=reshape_transform, ablation_layer=AblationLayerVit()) else: cam = methods[args.method](model=model, target_layers=target_layers, - use_cuda=args.use_cuda, reshape_transform=reshape_transform) rgb_img = cv2.imread(args.image_path, 1)[:, :, ::-1]