Commit 8f5e2cb

blessedcoolant authored and hipsterusername committed
feat: Add Depth Anything PreProcessor
1 parent 2aed6e2 commit 8f5e2cb
File tree

5 files changed: +695 −0 lines changed
invokeai/app/invocations/controlnet_image_processors.py

Lines changed: 30 additions & 0 deletions
@@ -30,6 +30,7 @@
 from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
 from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
 from invokeai.app.shared.fields import FieldDescriptions
+from invokeai.backend.image_util.depth_anything import DepthAnythingDetector
 
 from ...backend.model_management import BaseModelType
 from .baseinvocation import (
@@ -602,3 +603,32 @@ def run_processor(self, image: Image.Image):
         color_map = cv2.resize(color_map, (width, height), interpolation=cv2.INTER_NEAREST)
         color_map = Image.fromarray(color_map)
         return color_map
+
+
+DEPTH_ANYTHING_MODEL_SIZES = Literal["large", "base", "small"]
+
+
+@invocation(
+    "depth_anything_image_processor",
+    title="Depth Anything Processor",
+    tags=["controlnet", "depth", "depth anything"],
+    category="controlnet",
+    version="1.0.0",
+)
+class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
+    """Generates a depth map based on the Depth Anything algorithm"""
+
+    model_size: DEPTH_ANYTHING_MODEL_SIZES = InputField(
+        default="large", description="The size of the depth model to use"
+    )
+    offload: bool = InputField(default=False)
+
+    def run_processor(self, image):
+        depth_anything_detector = DepthAnythingDetector()
+        depth_anything_detector.load_model(model_size=self.model_size)
+
+        if image.mode == "RGBA":
+            image = image.convert("RGB")
+
+        processed_image = depth_anything_detector(image=image, offload=self.offload)
+        return processed_image
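
For orientation, here is a minimal sketch (not part of the diff) of how the new processor logic could be exercised outside of an InvokeAI graph; the input/output paths and the "small" model choice are illustrative assumptions, and the calls mirror what run_processor does above:

# Illustrative only: mirrors DepthAnythingImageProcessorInvocation.run_processor.
from PIL import Image

from invokeai.backend.image_util.depth_anything import DepthAnythingDetector

image = Image.open("input.png")  # hypothetical test image

detector = DepthAnythingDetector()
detector.load_model(model_size="small")  # downloads the checkpoint on first use

if image.mode == "RGBA":
    image = image.convert("RGB")  # same RGBA guard as run_processor

depth_map = detector(image=image, offload=True)  # offload frees the model afterwards
depth_map.save("depth_map.png")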
Lines changed: 107 additions & 0 deletions
@@ -0,0 +1,107 @@
+import pathlib
+from typing import Literal, Union
+
+import cv2
+import numpy as np
+import torch
+import torch.nn.functional as F
+from einops import repeat
+from PIL import Image
+from torchvision.transforms import Compose
+
+from invokeai.app.services.config.config_default import InvokeAIAppConfig
+from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2
+from invokeai.backend.image_util.depth_anything.utilities.util import NormalizeImage, PrepareForNet, Resize
+from invokeai.backend.util.devices import choose_torch_device
+from invokeai.backend.util.util import download_with_progress_bar
+
+config = InvokeAIAppConfig.get_config()
+
+DEPTH_ANYTHING_MODELS = {
+    "large": {
+        "url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitl14.pth?download=true",
+        "local": "sd-1/controlnet/annotator/depth_anything/depth_anything_vitl14.pth",
+    },
+    "base": {
+        "url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitb14.pth?download=true",
+        "local": "sd-1/controlnet/annotator/depth_anything/depth_anything_vitb14.pth",
+    },
+    "small": {
+        "url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vits14.pth?download=true",
+        "local": "sd-1/controlnet/annotator/depth_anything/depth_anything_vits14.pth",
+    },
+}
+
+
+transform = Compose(
+    [
+        Resize(
+            width=518,
+            height=518,
+            resize_target=False,
+            keep_aspect_ratio=True,
+            ensure_multiple_of=14,
+            resize_method="lower_bound",
+            image_interpolation_method=cv2.INTER_CUBIC,
+        ),
+        NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+        PrepareForNet(),
+    ]
+)
+
+
+class DepthAnythingDetector:
+    def __init__(self) -> None:
+        self.model = None
+        self.model_size: Union[Literal["large", "base", "small"], None] = None
+
+    def load_model(self, model_size: Literal["large", "base", "small"]):
+        DEPTH_ANYTHING_MODEL_PATH = pathlib.Path(config.models_path / DEPTH_ANYTHING_MODELS[model_size]["local"])
+        if not DEPTH_ANYTHING_MODEL_PATH.exists():
+            download_with_progress_bar(DEPTH_ANYTHING_MODELS[model_size]["url"], DEPTH_ANYTHING_MODEL_PATH)
+
+        if not self.model or model_size != self.model_size:
+            del self.model
+            self.model_size = model_size
+
+            if self.model_size == "small":
+                self.model = DPT_DINOv2(encoder="vits", features=64, out_channels=[48, 96, 192, 384], localhub=True)
+            if self.model_size == "base":
+                self.model = DPT_DINOv2(encoder="vitb", features=128, out_channels=[96, 192, 384, 768], localhub=True)
+            if self.model_size == "large":
+                self.model = DPT_DINOv2(
+                    encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024], localhub=True
+                )
+
+            self.model.load_state_dict(torch.load(DEPTH_ANYTHING_MODEL_PATH.as_posix(), map_location="cpu"))
+            self.model.eval()
+
+        self.model.to(choose_torch_device())
+        return self.model
+
+    def to(self, device):
+        self.model.to(device)
+        return self
+
+    def __call__(self, image, offload=False):
+        image = np.array(image, dtype=np.uint8)
+        original_width, original_height = image.shape[:2]
+        image = image[:, :, ::-1] / 255.0
+
+        image_width, image_height = image.shape[:2]
+        image = transform({"image": image})["image"]
+        image = torch.from_numpy(image).unsqueeze(0).to(choose_torch_device())
+
+        with torch.no_grad():
+            depth = self.model(image)
+            depth = F.interpolate(depth[None], (image_height, image_width), mode="bilinear", align_corners=False)[0, 0]
+            depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
+
+        depth_map = repeat(depth, "h w -> h w 3").cpu().numpy().astype(np.uint8)
+        depth_map = Image.fromarray(depth_map)
+        depth_map = depth_map.resize((original_height, original_width))
+
+        if offload:
+            del self.model
+
+        return depth_map
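
The tail of __call__ min-max normalizes the raw depth prediction onto the 0–255 range and replicates the single channel to three before building a PIL image. A small self-contained illustration of that step using a dummy tensor (not part of the diff, values are arbitrary):

import numpy as np
import torch
from einops import repeat
from PIL import Image

# Dummy "raw depth" values standing in for the model output.
depth = torch.tensor([[0.2, 1.5], [3.0, 7.5]])

# Linear min-max rescale onto 0..255, as in DepthAnythingDetector.__call__.
depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0

# Replicate the single channel to H x W x 3 and convert to an 8-bit RGB image.
depth_map = repeat(depth, "h w -> h w 3").cpu().numpy().astype(np.uint8)
image = Image.fromarray(depth_map)
print(image.size, image.mode)  # (2, 2) RGB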
Lines changed: 145 additions & 0 deletions
@@ -0,0 +1,145 @@
+import torch.nn as nn
+
+
+def _make_scratch(in_shape, out_shape, groups=1, expand=False):
+    scratch = nn.Module()
+
+    out_shape1 = out_shape
+    out_shape2 = out_shape
+    out_shape3 = out_shape
+    if len(in_shape) >= 4:
+        out_shape4 = out_shape
+
+    if expand:
+        out_shape1 = out_shape
+        out_shape2 = out_shape * 2
+        out_shape3 = out_shape * 4
+        if len(in_shape) >= 4:
+            out_shape4 = out_shape * 8
+
+    scratch.layer1_rn = nn.Conv2d(
+        in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+    )
+    scratch.layer2_rn = nn.Conv2d(
+        in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+    )
+    scratch.layer3_rn = nn.Conv2d(
+        in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+    )
+    if len(in_shape) >= 4:
+        scratch.layer4_rn = nn.Conv2d(
+            in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+        )
+
+    return scratch
+
+
+class ResidualConvUnit(nn.Module):
+    """Residual convolution module."""
+
+    def __init__(self, features, activation, bn):
+        """Init.
+
+        Args:
+            features (int): number of features
+        """
+        super().__init__()
+
+        self.bn = bn
+
+        self.groups = 1
+
+        self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
+
+        self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
+
+        if self.bn:
+            self.bn1 = nn.BatchNorm2d(features)
+            self.bn2 = nn.BatchNorm2d(features)
+
+        self.activation = activation
+
+        self.skip_add = nn.quantized.FloatFunctional()
+
+    def forward(self, x):
+        """Forward pass.
+
+        Args:
+            x (tensor): input
+
+        Returns:
+            tensor: output
+        """
+
+        out = self.activation(x)
+        out = self.conv1(out)
+        if self.bn:
+            out = self.bn1(out)
+
+        out = self.activation(out)
+        out = self.conv2(out)
+        if self.bn:
+            out = self.bn2(out)
+
+        if self.groups > 1:
+            out = self.conv_merge(out)
+
+        return self.skip_add.add(out, x)
+
+
+class FeatureFusionBlock(nn.Module):
+    """Feature fusion block."""
+
+    def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True, size=None):
+        """Init.
+
+        Args:
+            features (int): number of features
+        """
+        super(FeatureFusionBlock, self).__init__()
+
+        self.deconv = deconv
+        self.align_corners = align_corners
+
+        self.groups = 1
+
+        self.expand = expand
+        out_features = features
+        if self.expand:
+            out_features = features // 2
+
+        self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
+
+        self.resConfUnit1 = ResidualConvUnit(features, activation, bn)
+        self.resConfUnit2 = ResidualConvUnit(features, activation, bn)
+
+        self.skip_add = nn.quantized.FloatFunctional()
+
+        self.size = size
+
+    def forward(self, *xs, size=None):
+        """Forward pass.
+
+        Returns:
+            tensor: output
+        """
+        output = xs[0]
+
+        if len(xs) == 2:
+            res = self.resConfUnit1(xs[1])
+            output = self.skip_add.add(output, res)
+
+        output = self.resConfUnit2(output)
+
+        if (size is None) and (self.size is None):
+            modifier = {"scale_factor": 2}
+        elif size is None:
+            modifier = {"size": self.size}
+        else:
+            modifier = {"size": size}
+
+        output = nn.functional.interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)
+
+        output = self.out_conv(output)
+
+        return output
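
As a quick sanity check of the decoder blocks above, a sketch (not part of the diff) that runs FeatureFusionBlock on dummy tensors: it fuses an optional lateral feature into the input, refines it with two ResidualConvUnits, and upsamples by 2x before the 1x1 output projection. The shapes and activation choice below are illustrative, and the class is assumed to be importable from this module:

import torch
import torch.nn as nn

block = FeatureFusionBlock(features=64, activation=nn.ReLU(), bn=False)

x = torch.randn(1, 64, 16, 16)     # coarser decoder feature
skip = torch.randn(1, 64, 16, 16)  # lateral feature from a scratch.layerN_rn conv

out = block(x, skip)
print(out.shape)  # torch.Size([1, 64, 32, 32]) -- fused, refined, upsampled 2x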

0 commit comments
