diff --git a/depth_anything_v2/dpt.py b/depth_anything_v2/dpt.py index 18d3e6f8..9fde5427 100644 --- a/depth_anything_v2/dpt.py +++ b/depth_anything_v2/dpt.py @@ -214,8 +214,7 @@ def image2tensor(self, raw_image, input_size=518): image = transform({'image': image})['image'] image = torch.from_numpy(image).unsqueeze(0) - - DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu' + DEVICE = next(self.pretrained.parameters()).device image = image.to(DEVICE) return image, (h, w)