Skip to content

Commit 4f0dfbd

Browse files
wip: depth_anything_v2 initial implementation
1 parent b70ac88 commit 4f0dfbd

File tree

13 files changed

+1118
-180
lines changed

13 files changed

+1118
-180
lines changed

invokeai/backend/image_util/depth_anything/__init__.py

Lines changed: 14 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,66 +1,46 @@
11
from pathlib import Path
22
from typing import Literal
33

4-
import cv2
54
import numpy as np
65
import torch
7-
import torch.nn.functional as F
86
from einops import repeat
97
from PIL import Image
10-
from torchvision.transforms import Compose
118

129
from invokeai.app.services.config.config_default import get_config
13-
from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2
14-
from invokeai.backend.image_util.depth_anything.utilities.util import NormalizeImage, PrepareForNet, Resize
10+
from invokeai.backend.image_util.depth_anything.v2.dpt import DepthAnythingV2
1511
from invokeai.backend.util.logging import InvokeAILogger
1612

1713
config = get_config()
1814
logger = InvokeAILogger.get_logger(config=config)
1915

2016
DEPTH_ANYTHING_MODELS = {
21-
"large": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitl14.pth?download=true",
22-
"base": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitb14.pth?download=true",
23-
"small": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vits14.pth?download=true",
17+
"large": "https://huggingface.co/depth-anything/Depth-Anything-V2-Large/resolve/main/depth_anything_v2_vitl.pth?download=true",
18+
"base": "https://huggingface.co/depth-anything/Depth-Anything-V2-Base/resolve/main/depth_anything_v2_vitb.pth?download=true",
19+
"small": "https://huggingface.co/depth-anything/Depth-Anything-V2-Small/resolve/main/depth_anything_v2_vits.pth?download=true",
2420
}
2521

2622

27-
transform = Compose(
28-
[
29-
Resize(
30-
width=518,
31-
height=518,
32-
resize_target=False,
33-
keep_aspect_ratio=True,
34-
ensure_multiple_of=14,
35-
resize_method="lower_bound",
36-
image_interpolation_method=cv2.INTER_CUBIC,
37-
),
38-
NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
39-
PrepareForNet(),
40-
]
41-
)
42-
43-
4423
class DepthAnythingDetector:
45-
def __init__(self, model: DPT_DINOv2, device: torch.device) -> None:
24+
def __init__(self, model: DepthAnythingV2, device: torch.device) -> None:
4625
self.model = model
4726
self.device = device
4827

4928
@staticmethod
5029
def load_model(
51-
model_path: Path, device: torch.device, model_size: Literal["large", "base", "small"] = "small"
52-
) -> DPT_DINOv2:
30+
model_path: Path, device: torch.device, model_size: Literal["large", "base", "small", "giant"] = "small"
31+
) -> DepthAnythingV2:
5332
match model_size:
5433
case "small":
55-
model = DPT_DINOv2(encoder="vits", features=64, out_channels=[48, 96, 192, 384])
34+
model = DepthAnythingV2(encoder="vits", features=64, out_channels=[48, 96, 192, 384])
5635
case "base":
57-
model = DPT_DINOv2(encoder="vitb", features=128, out_channels=[96, 192, 384, 768])
36+
model = DepthAnythingV2(encoder="vitb", features=128, out_channels=[96, 192, 384, 768])
5837
case "large":
59-
model = DPT_DINOv2(encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024])
38+
model = DepthAnythingV2(encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024])
39+
case "giant":
40+
model = DepthAnythingV2(encoder="vitg", features=384, out_channels=[1536, 1536, 1536, 1536])
6041

6142
model.load_state_dict(torch.load(model_path.as_posix(), map_location="cpu"))
6243
model.eval()
63-
6444
model.to(device)
6545
return model
6646

@@ -70,18 +50,13 @@ def __call__(self, image: Image.Image, resolution: int = 512) -> Image.Image:
7050
return image
7151

7252
np_image = np.array(image, dtype=np.uint8)
73-
np_image = np_image[:, :, ::-1] / 255.0
74-
7553
image_height, image_width = np_image.shape[:2]
76-
np_image = transform({"image": np_image})["image"]
77-
tensor_image = torch.from_numpy(np_image).unsqueeze(0).to(self.device)
7854

7955
with torch.no_grad():
80-
depth = self.model(tensor_image)
81-
depth = F.interpolate(depth[None], (image_height, image_width), mode="bilinear", align_corners=False)[0, 0]
56+
depth = self.model.infer_image(np_image)
8257
depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
8358

84-
depth_map = repeat(depth, "h w -> h w 3").cpu().numpy().astype(np.uint8)
59+
depth_map = repeat(depth, "h w -> h w 3").astype(np.uint8)
8560
depth_map = Image.fromarray(depth_map)
8661

8762
new_height = int(image_height * (resolution / image_width))

0 commit comments

Comments
 (0)