
Commit 507f852

Add DepthAnythingV2. (#2377)
* Add `DepthAnythingBackbone`.
* Add DepthAnythingV2 conversion script.
* Update docstrings. Fix loss `None` object bug.
* Fix DINOV2 test.
* Fix test.
* Use numpy as the inputs.
* Rename the key of the pyramid outputs in DINOV2.
* Resolve comments.
* Re-enable the quantization check.

A usage sketch based on the new backbone's docstring follows the file summary below.
1 parent 049b25d commit 507f852

21 files changed: +2,038 / -16 lines
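For orientation, here is a minimal usage sketch assembled from the backbone docstring added in this commit. It assumes the `depth_anything_v2_small` preset is available; the test file below notes that presets are not uploaded yet.

```python
import numpy as np

import keras_hub

# Load the pretrained backbone by preset name (assumed to be published)
# and run it on a dummy 518x518 RGB image to get a depth map.
backbone = keras_hub.models.DepthAnythingBackbone.from_preset(
    "depth_anything_v2_small"
)
depth = backbone({"images": np.ones((1, 518, 518, 3), dtype="float32")})
```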

keras_hub/api/layers/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -87,6 +87,9 @@
 from keras_hub.src.models.densenet.densenet_image_converter import (
     DenseNetImageConverter as DenseNetImageConverter,
 )
+from keras_hub.src.models.depth_anything.depth_anything_image_converter import (
+    DepthAnythingImageConverter as DepthAnythingImageConverter,
+)
 from keras_hub.src.models.dinov2.dinov2_image_converter import (
     DINOV2ImageConverter as DINOV2ImageConverter,
 )

keras_hub/api/models/__init__.py

Lines changed: 15 additions & 0 deletions
@@ -166,6 +166,21 @@
 from keras_hub.src.models.densenet.densenet_image_classifier_preprocessor import (
     DenseNetImageClassifierPreprocessor as DenseNetImageClassifierPreprocessor,
 )
+from keras_hub.src.models.depth_anything.depth_anything_backbone import (
+    DepthAnythingBackbone as DepthAnythingBackbone,
+)
+from keras_hub.src.models.depth_anything.depth_anything_depth_estimator import (
+    DepthAnythingDepthEstimator as DepthAnythingDepthEstimator,
+)
+from keras_hub.src.models.depth_anything.depth_anything_depth_estimator_preprocessor import (
+    DepthAnythingDepthEstimatorPreprocessor as DepthAnythingDepthEstimatorPreprocessor,
+)
+from keras_hub.src.models.depth_estimator import (
+    DepthEstimator as DepthEstimator,
+)
+from keras_hub.src.models.depth_estimator_preprocessor import (
+    DepthEstimatorPreprocessor as DepthEstimatorPreprocessor,
+)
 from keras_hub.src.models.dinov2.dinov2_backbone import (
     DINOV2Backbone as DINOV2Backbone,
 )
Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
from keras_hub.src.models.depth_anything.depth_anything_backbone import (
    DepthAnythingBackbone,
)
from keras_hub.src.models.depth_anything.depth_anything_presets import (
    backbone_presets,
)
from keras_hub.src.utils.preset_utils import register_presets

register_presets(backbone_presets, DepthAnythingBackbone)
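For context, a small sketch of what the `register_presets` call above enables: registered preset names become discoverable on the backbone class, which is also how the test file below iterates them (illustrative only, since the presets are not uploaded yet).

```python
from keras_hub.src.models.depth_anything.depth_anything_backbone import (
    DepthAnythingBackbone,
)

# Once presets are registered (and uploaded), their names are listed here.
print(list(DepthAnythingBackbone.presets))
```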
Lines changed: 232 additions & 0 deletions
@@ -0,0 +1,232 @@
import keras
from keras import layers

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.backbone import Backbone
from keras_hub.src.models.depth_anything.depth_anything_layers import (
    DepthAnythingDepthEstimationHead,
)
from keras_hub.src.models.depth_anything.depth_anything_layers import (
    DepthAnythingNeck,
)
from keras_hub.src.models.dinov2 import DINOV2Backbone
from keras_hub.src.utils.keras_utils import standardize_data_format


@keras_hub_export("keras_hub.models.DepthAnythingBackbone")
class DepthAnythingBackbone(Backbone):
    """DepthAnything core network with hyperparameters.

    DepthAnything offers powerful monocular depth estimation as described in
    [Depth Anything V2](https://arxiv.org/abs/2406.09414).

    The default constructor gives a fully customizable, randomly initialized
    DepthAnything model with any number of layers, heads, and embedding
    dimensions by providing a DINOV2 backbone as the `image_encoder`. To load
    preset architectures and weights, use the `from_preset` constructor.

    Args:
        image_encoder: The DINOV2 image encoder for encoding the input images.
        reassemble_factors: List of float. The reassemble factor for each
            feature map from the image encoder. The length of the list must be
            equal to the number of feature maps from the image encoder.
        neck_hidden_dims: List of int. The sizes of the neck hidden states.
        fusion_hidden_dim: int. The size of the fusion hidden state.
        head_hidden_dim: int. The size of the head hidden state.
        head_in_index: int. The index to select the feature from the neck
            features as the input to the head.
        feature_keys: List of string. The keys to select the feature maps from
            the image encoder. If `None`, all feature maps from the image
            encoder will be used. Defaults to `None`.
        data_format: `None` or str. If specified, either `"channels_last"` or
            `"channels_first"`. The ordering of the dimensions in the
            inputs. `"channels_last"` corresponds to inputs with shape
            `(batch_size, height, width, channels)`
            while `"channels_first"` corresponds to inputs with shape
            `(batch_size, channels, height, width)`. It defaults to the
            `image_data_format` value found in your Keras config file at
            `~/.keras/keras.json`. If you never set it, then it will be
            `"channels_last"`.
        dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
            for the model's computations and weights. Note that some
            computations, such as softmax and layer normalization, will always
            be done in float32 precision regardless of dtype.

    Example:
    ```python
    # Pretrained DepthAnything model.
    input_data = {
        "images": np.ones(shape=(1, 518, 518, 3), dtype="float32"),
    }
    model = keras_hub.models.DepthAnythingBackbone.from_preset(
        "depth_anything_v2_small"
    )
    model(input_data)

    # Pretrained DepthAnything model with custom image shape.
    input_data = {
        "images": np.ones(shape=(1, 224, 224, 3), dtype="float32"),
    }
    model = keras_hub.models.DepthAnythingBackbone.from_preset(
        "depth_anything_v2_small", image_shape=(224, 224, 3)
    )
    model(input_data)

    # Randomly initialized DepthAnything model with custom config.
    image_encoder = keras_hub.models.DINOV2Backbone(
        patch_size=14,
        num_layers=4,
        hidden_dim=32,
        num_heads=2,
        intermediate_dim=128,
        image_shape=(224, 224, 3),
        position_embedding_shape=(518, 518),
    )
    model = keras_hub.models.DepthAnythingBackbone(
        image_encoder=image_encoder,
        reassemble_factors=[4, 2, 1, 0.5],
        neck_hidden_dims=[16, 32, 64, 128],
        fusion_hidden_dim=128,
        head_hidden_dim=16,
        head_in_index=-1,
        feature_keys=["Stage1", "Stage2", "Stage3", "Stage4"],
    )
    model(input_data)
    ```
    """

    def __init__(
        self,
        image_encoder,
        reassemble_factors,
        neck_hidden_dims,
        fusion_hidden_dim,
        head_hidden_dim,
        head_in_index,
        feature_keys=None,
        data_format=None,
        dtype=None,
        **kwargs,
    ):
        if not isinstance(image_encoder, DINOV2Backbone):
            raise ValueError(
                "`image_encoder` must be a `DINOV2Backbone`. "
                f"Received image_encoder={image_encoder} "
                f"(of type {type(image_encoder)})."
            )
        if feature_keys is not None:
            feature_keys = [str(key) for key in feature_keys]
            for key in feature_keys:
                if key not in image_encoder.pyramid_outputs:
                    raise ValueError(
                        "All `feature_keys` must be in "
                        "`image_encoder.pyramid_outputs`. "
                        f"Received feature_keys={feature_keys}, but "
                        "`image_encoder.pyramid_outputs` contains "
                        f"{list(image_encoder.pyramid_outputs.keys())}."
                    )
        else:
            feature_keys = list(image_encoder.pyramid_outputs.keys())
        if len(reassemble_factors) != len(feature_keys):
            raise ValueError(
                "The length of `reassemble_factors` must be equal to the "
                "length of `feature_keys`. "
                f"Received len(reassemble_factors)={len(reassemble_factors)}, "
                f"len(feature_keys)={len(feature_keys)}."
            )
        data_format = standardize_data_format(data_format)
        patch_size = image_encoder.patch_size
        backbone_hidden_dim = image_encoder.hidden_dim
        image_shape = image_encoder.image_shape
        if data_format == "channels_last":
            image_size = (image_shape[0], image_shape[1])
        else:
            image_size = (image_shape[1], image_shape[2])

        # === Layers ===
        pyramid_outputs = {
            key: value
            for key, value in image_encoder.pyramid_outputs.items()
            if key in feature_keys
        }
        self.feature_extractor = keras.Model(
            inputs=image_encoder.inputs,
            outputs=pyramid_outputs,
        )
        self.feature_extractor.dtype_policy = image_encoder.dtype_policy
        self.neck = DepthAnythingNeck(
            patch_size=patch_size,
            image_size=image_size,
            backbone_hidden_dim=backbone_hidden_dim,
            neck_hidden_dims=neck_hidden_dims,
            reassemble_factors=reassemble_factors,
            fusion_hidden_dim=fusion_hidden_dim,
            num_cls_tokens=1,
            num_register_tokens=image_encoder.num_register_tokens,
            data_format=data_format,
            dtype=dtype,
            name="neck",
        )
        self.head = DepthAnythingDepthEstimationHead(
            patch_size=patch_size,
            patch_height=image_size[0] // patch_size,
            patch_width=image_size[1] // patch_size,
            fusion_hidden_dim=fusion_hidden_dim,
            head_hidden_dim=head_hidden_dim,
            head_in_index=head_in_index,
            data_format=data_format,
            dtype=dtype,
            name="head",
        )

        # === Functional Model ===
        image_input = layers.Input(shape=image_shape, name="images")
        features = self.feature_extractor(image_input)
        features = self.neck(list(features.values()))
        depth_output = self.head(features)
        super().__init__(
            inputs=image_input,
            outputs=depth_output,
            dtype=dtype,
            **kwargs,
        )

        # === Config ===
        self.image_encoder = image_encoder
        self.reassemble_factors = reassemble_factors
        self.neck_hidden_dims = neck_hidden_dims
        self.fusion_hidden_dim = fusion_hidden_dim
        self.head_hidden_dim = head_hidden_dim
        self.head_in_index = head_in_index
        self.feature_keys = feature_keys

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "image_encoder": layers.serialize(self.image_encoder),
                "reassemble_factors": self.reassemble_factors,
                "neck_hidden_dims": self.neck_hidden_dims,
                "fusion_hidden_dim": self.fusion_hidden_dim,
                "head_hidden_dim": self.head_hidden_dim,
                "head_in_index": self.head_in_index,
                "feature_keys": self.feature_keys,
            }
        )
        return config

    @classmethod
    def from_config(cls, config, custom_objects=None):
        config = config.copy()

        # Propagate `dtype` to `image_encoder` if needed.
        if "dtype" in config and config["dtype"] is not None:
            dtype_config = config["dtype"]
            if "dtype" not in config["image_encoder"]["config"]:
                config["image_encoder"]["config"]["dtype"] = dtype_config

        # We expect submodels to be instantiated.
        config["image_encoder"] = layers.deserialize(
            config["image_encoder"], custom_objects=custom_objects
        )
        return cls(**config)
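A short sketch of the serialization round-trip supported by the `get_config`/`from_config` pair above, reusing the tiny randomly initialized configuration from the class docstring; the sizes are illustrative, and `feature_keys` is left as `None` so all pyramid outputs are used.

```python
import numpy as np

import keras_hub

# Build a small DepthAnything model around a randomly initialized DINOV2.
image_encoder = keras_hub.models.DINOV2Backbone(
    patch_size=14,
    num_layers=4,
    hidden_dim=32,
    num_heads=2,
    intermediate_dim=128,
    image_shape=(224, 224, 3),
    position_embedding_shape=(518, 518),
)
model = keras_hub.models.DepthAnythingBackbone(
    image_encoder=image_encoder,
    reassemble_factors=[4, 2, 1, 0.5],
    neck_hidden_dims=[16, 32, 64, 128],
    fusion_hidden_dim=128,
    head_hidden_dim=16,
    head_in_index=-1,
)

# `get_config` serializes the nested `image_encoder`; `from_config`
# deserializes it back into a `DINOV2Backbone` before calling `__init__`.
config = model.get_config()
restored = keras_hub.models.DepthAnythingBackbone.from_config(config)
depth = restored({"images": np.ones((1, 224, 224, 3), dtype="float32")})
```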
Lines changed: 72 additions & 0 deletions
@@ -0,0 +1,72 @@
import numpy as np
import pytest

from keras_hub.src.models.depth_anything.depth_anything_backbone import (
    DepthAnythingBackbone,
)
from keras_hub.src.models.dinov2.dinov2_backbone import DINOV2Backbone
from keras_hub.src.tests.test_case import TestCase


class DepthAnythingBackboneTest(TestCase):
    def setUp(self):
        # A tiny randomly initialized DINOV2 encoder keeps the tests fast.
        image_encoder = DINOV2Backbone(
            14,
            4,
            16,
            2,
            16 * 4,
            1.0,
            0,
            image_shape=(70, 70, 3),
            apply_layernorm=True,
            name="image_encoder",
        )
        self.init_kwargs = {
            "image_encoder": image_encoder,
            "reassemble_factors": [4, 2, 1, 0.5],
            "neck_hidden_dims": [16, 32, 64, 128],
            "fusion_hidden_dim": 128,
            "head_hidden_dim": 16,
            "head_in_index": -1,
            "feature_keys": ["stage1", "stage2", "stage3", "stage4"],
        }
        self.input_data = np.ones((2, 70, 70, 3), dtype="float32")

    def test_backbone_basics(self):
        self.run_backbone_test(
            cls=DepthAnythingBackbone,
            init_kwargs=self.init_kwargs,
            input_data=self.input_data,
            expected_output_shape=(2, 70, 70, 1),
        )

    @pytest.mark.large
    def test_saved_model(self):
        self.run_model_saving_test(
            cls=DepthAnythingBackbone,
            init_kwargs=self.init_kwargs,
            input_data=self.input_data,
        )

    @pytest.mark.kaggle_key_required
    @pytest.mark.extra_large
    def test_smallest_preset(self):
        self.skipTest("Presets are not uploaded yet.")
        self.run_preset_test(
            cls=DepthAnythingBackbone,
            preset="depth_anything_v2_small",
            input_data=self.input_data,
            expected_output_shape=(2, 70, 70, 1),
        )

    @pytest.mark.kaggle_key_required
    @pytest.mark.extra_large
    def test_all_presets(self):
        self.skipTest("Presets are not uploaded yet.")
        for preset in DepthAnythingBackbone.presets:
            self.run_preset_test(
                cls=DepthAnythingBackbone,
                preset=preset,
                input_data=self.input_data,
            )
