Commit 68da5c6

feat: Add Depth Anything PreProcessor (#5548)
## What type of PR is this? (check all applicable)

- [x] Feature

## Have you discussed this change with the InvokeAI team?

- [x] Yes

## Have you updated all relevant documentation?

- [x] No

## Description

- This adds the newly released Depth Anything to InvokeAI. A new node, `Depth Anything Processor`, has been added to generate depth maps using this new technique: https://depth-anything.github.io
- All related checkpoints will be downloaded automatically on first boot. The `DinoV2` models will be loaded into your torch cache dir, and the checkpoints pertaining to Depth Anything will be downloaded to `any/annotators/depth_anything`.
- Alternatively, you can find the checkpoints here and download them to that folder (see the pre-fetch sketch further below): https://huggingface.co/spaces/LiheYoung/Depth-Anything/tree/main/checkpoints
- This depth map can be used with any depth ControlNet model out there, but the folks at Depth Anything have also released a custom fine-tuned ControlNet model. From my limited testing, I still prefer the original depth model because this one seems to produce odd artifacts; I'm not sure whether that is a problem specific to Invoke or to the model itself, and I'll test more later. Place these in your ControlNet folder like your other ControlNets. You can get them here: https://huggingface.co/spaces/LiheYoung/Depth-Anything/tree/main/checkpoints_controlnet
- Also available in the Linear UI.
- Depth Anything has three models: `large`, `base`, and `small`. I've defaulted the processor to `small`, but a user can change to the larger models if they wish. Small is much faster, but at some cost in quality.
- Depth Anything is now the default processor for depth ControlNet models.

## Screenshots

![opera_o3jHnWxVRi](https://github.com/invoke-ai/InvokeAI/assets/54517381/573c66f3-1492-45b0-b6df-25756f5e1d1a)

## Merge Plan

DO NOT MERGE YET. Test it first; I'm also sure the model caching can be done better, because I don't think I've done that at all. I would appreciate it if @brandonrising, @lstein, or anyone else could take a look at that part of it.
2 parents 61cf4d4 + f82744b commit 68da5c6
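
The description above notes that the checkpoints can also be downloaded manually into `any/annotators/depth_anything`. Below is a minimal pre-fetch sketch, assuming a working InvokeAI environment; the URL and relative path are copied from the `DEPTH_ANYTHING_MODELS` table added in this diff, and the snippet itself is illustrative rather than part of the commit.

```python
# Illustrative pre-fetch sketch; not part of the commit. URL and relative path
# are taken verbatim from DEPTH_ANYTHING_MODELS in the new depth_anything module.
import pathlib

from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.backend.util.util import download_with_progress_bar

config = InvokeAIAppConfig.get_config()

url = "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vits14.pth?download=true"
dest = pathlib.Path(config.models_path / "any/annotators/depth_anything/depth_anything_vits14.pth")

if not dest.exists():
    # Same call the DepthAnythingDetector makes on first use of the "small" model.
    download_with_progress_bar(url, dest)
```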

File tree: 12 files changed (+1105, -20 lines)


invokeai/app/invocations/controlnet_image_processors.py

Lines changed: 31 additions & 0 deletions
@@ -30,6 +30,7 @@
 from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
 from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
 from invokeai.app.shared.fields import FieldDescriptions
+from invokeai.backend.image_util.depth_anything import DepthAnythingDetector
 
 from ...backend.model_management import BaseModelType
 from .baseinvocation import (
@@ -602,3 +603,33 @@ def run_processor(self, image: Image.Image):
         color_map = cv2.resize(color_map, (width, height), interpolation=cv2.INTER_NEAREST)
         color_map = Image.fromarray(color_map)
         return color_map
+
+
+DEPTH_ANYTHING_MODEL_SIZES = Literal["large", "base", "small"]
+
+
+@invocation(
+    "depth_anything_image_processor",
+    title="Depth Anything Processor",
+    tags=["controlnet", "depth", "depth anything"],
+    category="controlnet",
+    version="1.0.0",
+)
+class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
+    """Generates a depth map based on the Depth Anything algorithm"""
+
+    model_size: DEPTH_ANYTHING_MODEL_SIZES = InputField(
+        default="small", description="The size of the depth model to use"
+    )
+    resolution: int = InputField(default=512, ge=64, multiple_of=64, description=FieldDescriptions.image_res)
+    offload: bool = InputField(default=False)
+
+    def run_processor(self, image):
+        depth_anything_detector = DepthAnythingDetector()
+        depth_anything_detector.load_model(model_size=self.model_size)
+
+        if image.mode == "RGBA":
+            image = image.convert("RGB")
+
+        processed_image = depth_anything_detector(image=image, resolution=self.resolution, offload=self.offload)
+        return processed_image
Lines changed: 109 additions & 0 deletions
@@ -0,0 +1,109 @@
import pathlib
from typing import Literal, Union

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from einops import repeat
from PIL import Image
from torchvision.transforms import Compose

from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2
from invokeai.backend.image_util.depth_anything.utilities.util import NormalizeImage, PrepareForNet, Resize
from invokeai.backend.util.devices import choose_torch_device
from invokeai.backend.util.util import download_with_progress_bar

config = InvokeAIAppConfig.get_config()

DEPTH_ANYTHING_MODELS = {
    "large": {
        "url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitl14.pth?download=true",
        "local": "any/annotators/depth_anything/depth_anything_vitl14.pth",
    },
    "base": {
        "url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitb14.pth?download=true",
        "local": "any/annotators/depth_anything/depth_anything_vitb14.pth",
    },
    "small": {
        "url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vits14.pth?download=true",
        "local": "any/annotators/depth_anything/depth_anything_vits14.pth",
    },
}


transform = Compose(
    [
        Resize(
            width=518,
            height=518,
            resize_target=False,
            keep_aspect_ratio=True,
            ensure_multiple_of=14,
            resize_method="lower_bound",
            image_interpolation_method=cv2.INTER_CUBIC,
        ),
        NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        PrepareForNet(),
    ]
)


class DepthAnythingDetector:
    def __init__(self) -> None:
        self.model = None
        self.model_size: Union[Literal["large", "base", "small"], None] = None

    def load_model(self, model_size: Literal["large", "base", "small"] = "small"):
        DEPTH_ANYTHING_MODEL_PATH = pathlib.Path(config.models_path / DEPTH_ANYTHING_MODELS[model_size]["local"])
        if not DEPTH_ANYTHING_MODEL_PATH.exists():
            download_with_progress_bar(DEPTH_ANYTHING_MODELS[model_size]["url"], DEPTH_ANYTHING_MODEL_PATH)

        if not self.model or model_size != self.model_size:
            del self.model
            self.model_size = model_size

            match self.model_size:
                case "small":
                    self.model = DPT_DINOv2(encoder="vits", features=64, out_channels=[48, 96, 192, 384])
                case "base":
                    self.model = DPT_DINOv2(encoder="vitb", features=128, out_channels=[96, 192, 384, 768])
                case "large":
                    self.model = DPT_DINOv2(encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024])
                case _:
                    raise TypeError("Not a supported model")

            self.model.load_state_dict(torch.load(DEPTH_ANYTHING_MODEL_PATH.as_posix(), map_location="cpu"))
            self.model.eval()

        self.model.to(choose_torch_device())
        return self.model

    def to(self, device):
        self.model.to(device)
        return self

    def __call__(self, image, resolution=512, offload=False):
        image = np.array(image, dtype=np.uint8)
        image = image[:, :, ::-1] / 255.0

        image_height, image_width = image.shape[:2]
        image = transform({"image": image})["image"]
        image = torch.from_numpy(image).unsqueeze(0).to(choose_torch_device())

        with torch.no_grad():
            depth = self.model(image)
            depth = F.interpolate(depth[None], (image_height, image_width), mode="bilinear", align_corners=False)[0, 0]
            depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0

        depth_map = repeat(depth, "h w -> h w 3").cpu().numpy().astype(np.uint8)
        depth_map = Image.fromarray(depth_map)

        new_height = int(image_height * (resolution / image_width))
        depth_map = depth_map.resize((resolution, new_height))

        if offload:
            del self.model

        return depth_map
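
For quick local testing, the new `DepthAnythingDetector` can also be driven directly from Python, mirroring what `DepthAnythingImageProcessorInvocation.run_processor` does in the first file. A minimal sketch, assuming a working InvokeAI environment; the file paths are placeholders, not part of the commit:

```python
# Minimal usage sketch of DepthAnythingDetector; "input.png" / "input_depth.png"
# are placeholder paths for illustration only.
from PIL import Image

from invokeai.backend.image_util.depth_anything import DepthAnythingDetector

detector = DepthAnythingDetector()
detector.load_model(model_size="small")  # "large", "base", or "small"

image = Image.open("input.png")
if image.mode == "RGBA":
    image = image.convert("RGB")

# Returns a PIL depth map resized so its width equals `resolution`;
# offload=True drops the model reference after the call.
depth_map = detector(image=image, resolution=512, offload=False)
depth_map.save("input_depth.png")
```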
Lines changed: 145 additions & 0 deletions
@@ -0,0 +1,145 @@
import torch.nn as nn


def _make_scratch(in_shape, out_shape, groups=1, expand=False):
    scratch = nn.Module()

    out_shape1 = out_shape
    out_shape2 = out_shape
    out_shape3 = out_shape
    if len(in_shape) >= 4:
        out_shape4 = out_shape

    if expand:
        out_shape1 = out_shape
        out_shape2 = out_shape * 2
        out_shape3 = out_shape * 4
        if len(in_shape) >= 4:
            out_shape4 = out_shape * 8

    scratch.layer1_rn = nn.Conv2d(
        in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
    )
    scratch.layer2_rn = nn.Conv2d(
        in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
    )
    scratch.layer3_rn = nn.Conv2d(
        in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
    )
    if len(in_shape) >= 4:
        scratch.layer4_rn = nn.Conv2d(
            in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
        )

    return scratch


class ResidualConvUnit(nn.Module):
    """Residual convolution module."""

    def __init__(self, features, activation, bn):
        """Init.

        Args:
            features (int): number of features
        """
        super().__init__()

        self.bn = bn
        self.groups = 1

        self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
        self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)

        if self.bn:
            self.bn1 = nn.BatchNorm2d(features)
            self.bn2 = nn.BatchNorm2d(features)

        self.activation = activation
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input

        Returns:
            tensor: output
        """
        out = self.activation(x)
        out = self.conv1(out)
        if self.bn:
            out = self.bn1(out)

        out = self.activation(out)
        out = self.conv2(out)
        if self.bn:
            out = self.bn2(out)

        if self.groups > 1:
            out = self.conv_merge(out)

        return self.skip_add.add(out, x)


class FeatureFusionBlock(nn.Module):
    """Feature fusion block."""

    def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True, size=None):
        """Init.

        Args:
            features (int): number of features
        """
        super(FeatureFusionBlock, self).__init__()

        self.deconv = deconv
        self.align_corners = align_corners

        self.groups = 1

        self.expand = expand
        out_features = features
        if self.expand:
            out_features = features // 2

        self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)

        self.resConfUnit1 = ResidualConvUnit(features, activation, bn)
        self.resConfUnit2 = ResidualConvUnit(features, activation, bn)

        self.skip_add = nn.quantized.FloatFunctional()

        self.size = size

    def forward(self, *xs, size=None):
        """Forward pass.

        Returns:
            tensor: output
        """
        output = xs[0]

        if len(xs) == 2:
            res = self.resConfUnit1(xs[1])
            output = self.skip_add.add(output, res)

        output = self.resConfUnit2(output)

        if (size is None) and (self.size is None):
            modifier = {"scale_factor": 2}
        elif size is None:
            modifier = {"size": self.size}
        else:
            modifier = {"size": size}

        output = nn.functional.interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)

        output = self.out_conv(output)

        return output
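
To make the tensor shapes concrete, here is a hedged sketch of how these pieces compose in a DPT-style decoder for the `small` configuration used earlier (`features=64`, `out_channels=[48, 96, 192, 384]`). The actual wiring lives in `DPT_DINOv2`, which is not part of this excerpt, and the feature-map sizes below are made up for illustration:

```python
# Illustrative composition sketch; the real decoder is DPT_DINOv2 (model/dpt.py),
# not shown in this diff. Shapes are hypothetical.
import torch
import torch.nn as nn

# Project each encoder stage to a common 64-channel width.
scratch = _make_scratch([48, 96, 192, 384], 64, groups=1, expand=False)
fusion = FeatureFusionBlock(64, nn.ReLU(False), deconv=False, bn=False, align_corners=True)

stage3 = torch.randn(1, 192, 37, 37)   # hypothetical stage-3 encoder feature map
projected = scratch.layer3_rn(stage3)  # 3x3 conv projection -> (1, 64, 37, 37)
fused = fusion(projected)              # residual refinement + 2x bilinear upsample -> (1, 64, 74, 74)
print(fused.shape)
```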
