@@ -22,9 +22,7 @@ def _make_fusion_block(features, use_bn, size=None):
2222
2323
2424class DPTHead (nn .Module ):
25- def __init__ (
26- self , nclass , in_channels , features = 256 , use_bn = False , out_channels = [256 , 512 , 1024 , 1024 ], use_clstoken = False
27- ):
25+ def __init__ (self , nclass , in_channels , features , out_channels , use_bn = False , use_clstoken = False ):
2826 super (DPTHead , self ).__init__ ()
2927
3028 self .nclass = nclass
@@ -138,19 +136,18 @@ def forward(self, out_features, patch_h, patch_w):
138136class DPT_DINOv2 (nn .Module ):
139137 def __init__ (
140138 self ,
139+ features ,
140+ out_channels ,
141141 encoder = "vitl" ,
142- features = 256 ,
143- out_channels = [256 , 512 , 1024 , 1024 ],
144142 use_bn = False ,
145143 use_clstoken = False ,
146- localhub = True ,
147144 ):
148145 super (DPT_DINOv2 , self ).__init__ ()
149146
150147 assert encoder in ["vits" , "vitb" , "vitl" ]
151148
152149 # # in case the Internet connection is not stable, please load the DINOv2 locally
153- # if localhub :
150+ # if use_local :
154151 # self.pretrained = torch.hub.load(
155152 # torchhub_path / "facebookresearch_dinov2_main",
156153 # "dinov2_{:}14".format(encoder),
@@ -170,7 +167,7 @@ def __init__(
170167
171168 dim = self .pretrained .blocks [0 ].attn .qkv .in_features
172169
173- self .depth_head = DPTHead (1 , dim , features , use_bn , out_channels = out_channels , use_clstoken = use_clstoken )
170+ self .depth_head = DPTHead (1 , dim , features , out_channels = out_channels , use_bn = use_bn , use_clstoken = use_clstoken )
174171
175172 def forward (self , x ):
176173 h , w = x .shape [- 2 :]
0 commit comments