From bb9db06633a58886a62078815df8079070205f20 Mon Sep 17 00:00:00 2001 From: Yonghye Kwon Date: Sat, 22 Jun 2024 20:04:17 +0900 Subject: [PATCH] load the model onto a device before inference --- README.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 9bdfd78c..d6b02ddc 100644 --- a/README.md +++ b/README.md @@ -56,6 +56,8 @@ import torch from depth_anything_v2.dpt import DepthAnythingV2 +DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu' + model_configs = { 'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]}, 'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]}, @@ -67,7 +69,7 @@ encoder = 'vitl' # or 'vits', 'vitb', 'vitg' model = DepthAnythingV2(**model_configs[encoder]) model.load_state_dict(torch.load(f'checkpoints/depth_anything_v2_{encoder}.pth', map_location='cpu')) -model.eval() +model = model.to(DEVICE).eval() raw_img = cv2.imread('your/image/path') depth = model.infer_image(raw_img) # HxW raw depth map in numpy