diff --git a/depth_anything_v2/dpt.py b/depth_anything_v2/dpt.py index 18d3e6f8..908666a2 100644 --- a/depth_anything_v2/dpt.py +++ b/depth_anything_v2/dpt.py @@ -186,6 +186,7 @@ def forward(self, x): @torch.no_grad() def infer_image(self, raw_image, input_size=518): image, (h, w) = self.image2tensor(raw_image, input_size) + self.to(image.device) depth = self.forward(image)