Skip to content

Commit c859eb8

Browse files
blessedcoolanthipsterusername
authored andcommitted
fix: lint & other minor issues
1 parent 8f5e2cb commit c859eb8

File tree

2 files changed

+8
-13
lines changed

2 files changed

+8
-13
lines changed

invokeai/backend/image_util/depth_anything/__init__.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -65,13 +65,11 @@ def load_model(self, model_size=Literal["large", "base", "small"]):
6565
self.model_size = model_size
6666

6767
if self.model_size == "small":
68-
self.model = DPT_DINOv2(encoder="vits", features=64, out_channels=[48, 96, 192, 384], localhub=True)
68+
self.model = DPT_DINOv2(encoder="vits", features=64, out_channels=[48, 96, 192, 384])
6969
if self.model_size == "base":
70-
self.model = DPT_DINOv2(encoder="vitb", features=128, out_channels=[96, 192, 384, 768], localhub=True)
70+
self.model = DPT_DINOv2(encoder="vitb", features=128, out_channels=[96, 192, 384, 768])
7171
if self.model_size == "large":
72-
self.model = DPT_DINOv2(
73-
encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024], localhub=True
74-
)
72+
self.model = DPT_DINOv2(encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024])
7573

7674
self.model.load_state_dict(torch.load(DEPTH_ANYTHING_MODEL_PATH.as_posix(), map_location="cpu"))
7775
self.model.eval()

invokeai/backend/image_util/depth_anything/model/dpt.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,7 @@ def _make_fusion_block(features, use_bn, size=None):
2222

2323

2424
class DPTHead(nn.Module):
25-
def __init__(
26-
self, nclass, in_channels, features=256, use_bn=False, out_channels=[256, 512, 1024, 1024], use_clstoken=False
27-
):
25+
def __init__(self, nclass, in_channels, features, out_channels, use_bn=False, use_clstoken=False):
2826
super(DPTHead, self).__init__()
2927

3028
self.nclass = nclass
@@ -138,19 +136,18 @@ def forward(self, out_features, patch_h, patch_w):
138136
class DPT_DINOv2(nn.Module):
139137
def __init__(
140138
self,
139+
features,
140+
out_channels,
141141
encoder="vitl",
142-
features=256,
143-
out_channels=[256, 512, 1024, 1024],
144142
use_bn=False,
145143
use_clstoken=False,
146-
localhub=True,
147144
):
148145
super(DPT_DINOv2, self).__init__()
149146

150147
assert encoder in ["vits", "vitb", "vitl"]
151148

152149
# # in case the Internet connection is not stable, please load the DINOv2 locally
153-
# if localhub:
150+
# if use_local:
154151
# self.pretrained = torch.hub.load(
155152
# torchhub_path / "facebookresearch_dinov2_main",
156153
# "dinov2_{:}14".format(encoder),
@@ -170,7 +167,7 @@ def __init__(
170167

171168
dim = self.pretrained.blocks[0].attn.qkv.in_features
172169

173-
self.depth_head = DPTHead(1, dim, features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken)
170+
self.depth_head = DPTHead(1, dim, features, out_channels=out_channels, use_bn=use_bn, use_clstoken=use_clstoken)
174171

175172
def forward(self, x):
176173
h, w = x.shape[-2:]

0 commit comments

Comments
 (0)