
Commit c2e8786

[Community pipeline] Marigold depth estimation update -- align with marigold v0.1.5 (#7524)
* add resample option; check denoise_step; update ckpt path
* Add seeding in pipeline to increase reproducibility
* fix typo
* fix typo
Parent: ca61287

File tree: 2 files changed (+109 −21 lines)


examples/community/README.md

Lines changed: 22 additions & 2 deletions
````diff
@@ -85,14 +85,25 @@ This depth estimation pipeline processes a single input image through multiple d
 
 ```python
 import numpy as np
+import torch
 from PIL import Image
 from diffusers import DiffusionPipeline
 from diffusers.utils import load_image
 
+# Original DDIM version (higher quality)
+pipe = DiffusionPipeline.from_pretrained(
+    "prs-eth/marigold-v1-0",
+    custom_pipeline="marigold_depth_estimation"
+    # torch_dtype=torch.float16,  # (optional) Run with half-precision (16-bit float).
+    # variant="fp16",  # (optional) Use with `torch_dtype=torch.float16` to directly load the fp16 checkpoint.
+)
+
+# (New) LCM version (faster)
 pipe = DiffusionPipeline.from_pretrained(
-    "Bingxin/Marigold",
+    "prs-eth/marigold-lcm-v1-0",
     custom_pipeline="marigold_depth_estimation"
     # torch_dtype=torch.float16,  # (optional) Run with half-precision (16-bit float).
+    # variant="fp16",  # (optional) Use with `torch_dtype=torch.float16` to directly load the fp16 checkpoint.
 )
 
 pipe.to("cuda")
@@ -101,12 +112,21 @@ img_path_or_url = "https://share.phys.ethz.ch/~pf/bingkedata/marigold/pipeline_e
 image: Image.Image = load_image(img_path_or_url)
 
 pipeline_output = pipe(
-    image,       # Input image.
+    image,  # Input image.
+    # ----- recommended setting for DDIM version -----
     # denoising_steps=10,  # (optional) Number of denoising steps of each inference pass. Default: 10.
     # ensemble_size=10,  # (optional) Number of inference passes in the ensemble. Default: 10.
+    # ------------------------------------------------
+
+    # ----- recommended setting for LCM version ------
+    # denoising_steps=4,
+    # ensemble_size=5,
+    # -------------------------------------------------
+
     # processing_res=768,  # (optional) Maximum resolution of processing. If set to 0: will not resize at all. Defaults to 768.
     # match_input_res=True,  # (optional) Resize depth prediction to match input resolution.
     # batch_size=0,  # (optional) Inference batch size, no bigger than `num_ensemble`. If set to 0, the script will automatically decide the proper batch size. Defaults to 0.
+    # seed=2024,  # (optional) Random seed can be set to ensure additional reproducibility. Default: None (unseeded). Note: forcing `batch_size=1` helps to increase reproducibility. To ensure full reproducibility, deterministic mode needs to be used.
     # color_map="Spectral",  # (optional) Colormap used to colorize the depth map. Defaults to "Spectral". Set to `None` to skip colormap generation.
     # show_progress_bar=True,  # (optional) If true, will show progress bars of the inference progress.
 )
````
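For completeness, a minimal sketch of consuming the returned `MarigoldDepthOutput` fields documented in the pipeline file below; the save paths here are hypothetical:

```python
# Minimal sketch: reading the MarigoldDepthOutput fields after the call above.
depth: np.ndarray = pipeline_output.depth_np                 # predicted depth, values in [0, 1]
depth_colored: Image.Image = pipeline_output.depth_colored   # colorized map, or None if color_map=None

np.save("depth.npy", depth)                  # hypothetical output path
if depth_colored is not None:
    depth_colored.save("depth_colored.png")  # hypothetical output path
```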

examples/community/marigold_depth_estimation.py

Lines changed: 87 additions & 19 deletions
```diff
@@ -18,13 +18,15 @@
 # --------------------------------------------------------------------------
 
 
+import logging
 import math
 from typing import Dict, Union
 
 import matplotlib
 import numpy as np
 import torch
 from PIL import Image
+from PIL.Image import Resampling
 from scipy.optimize import minimize
 from torch.utils.data import DataLoader, TensorDataset
 from tqdm.auto import tqdm
@@ -34,13 +36,14 @@
     AutoencoderKL,
     DDIMScheduler,
     DiffusionPipeline,
+    LCMScheduler,
     UNet2DConditionModel,
 )
 from diffusers.utils import BaseOutput, check_min_version
 
 
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.28.0.dev0")
+check_min_version("0.25.0")
 
 
 class MarigoldDepthOutput(BaseOutput):
```
```diff
@@ -61,6 +64,19 @@ class MarigoldDepthOutput(BaseOutput):
     uncertainty: Union[None, np.ndarray]
 
 
+def get_pil_resample_method(method_str: str) -> Resampling:
+    resample_method_dic = {
+        "bilinear": Resampling.BILINEAR,
+        "bicubic": Resampling.BICUBIC,
+        "nearest": Resampling.NEAREST,
+    }
+    resample_method = resample_method_dic.get(method_str, None)
+    if resample_method is None:
+        raise ValueError(f"Unknown resampling method: {method_str}")
+    else:
+        return resample_method
+
+
 class MarigoldPipeline(DiffusionPipeline):
     """
     Pipeline for monocular depth estimation using Marigold: https://marigoldmonodepth.github.io.
```
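A quick illustration of the helper's contract, assuming Pillow >= 9.1 (where `PIL.Image.Resampling` exists):

```python
from PIL.Image import Resampling

# Known names map to the PIL enum; anything else raises ValueError.
assert get_pil_resample_method("bicubic") is Resampling.BICUBIC
try:
    get_pil_resample_method("lanczos")  # hypothetical unsupported name
except ValueError as err:
    print(err)  # Unknown resampling method: lanczos
```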
```diff
@@ -113,7 +129,9 @@ def __call__(
         ensemble_size: int = 10,
         processing_res: int = 768,
         match_input_res: bool = True,
+        resample_method: str = "bilinear",
         batch_size: int = 0,
+        seed: Union[int, None] = None,
         color_map: str = "Spectral",
         show_progress_bar: bool = True,
         ensemble_kwargs: Dict = None,
```
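The two new keyword arguments slot into the existing call signature; a hypothetical seeded call, with `pipe` and `image` set up as in the README example above and the settings recommended for the LCM checkpoint:

```python
# Hypothetical call exercising the new `resample_method` and `seed` arguments.
pipeline_output = pipe(
    image,
    denoising_steps=4,           # recommended for the LCM checkpoint
    ensemble_size=5,             # recommended for the LCM checkpoint
    resample_method="bilinear",  # new: "bilinear", "bicubic", or "nearest"
    seed=2024,                   # new: seeds the initial depth latent per inference pass
)
```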
```diff
@@ -129,14 +147,18 @@ def __call__(
                 If set to 0: will not resize at all.
             match_input_res (`bool`, *optional*, defaults to `True`):
                 Resize depth prediction to match input resolution.
-                Only valid if `limit_input_res` is not None.
+                Only valid if `processing_res` > 0.
+            resample_method (`str`, *optional*, defaults to `bilinear`):
+                Resampling method used to resize images and depth predictions. This can be one of `bilinear`, `bicubic` or `nearest`.
             denoising_steps (`int`, *optional*, defaults to `10`):
                 Number of diffusion denoising steps (DDIM) during inference.
             ensemble_size (`int`, *optional*, defaults to `10`):
                 Number of predictions to be ensembled.
             batch_size (`int`, *optional*, defaults to `0`):
                 Inference batch size, no bigger than `num_ensemble`.
                 If set to 0, the script will automatically decide the proper batch size.
+            seed (`int`, *optional*, defaults to `None`):
+                Reproducibility seed.
             show_progress_bar (`bool`, *optional*, defaults to `True`):
                 Display a progress bar of diffusion denoising.
             color_map (`str`, *optional*, defaults to `"Spectral"`, pass `None` to skip colorized depth map generation):
@@ -146,8 +168,7 @@ def __call__(
         Returns:
             `MarigoldDepthOutput`: Output class for Marigold monocular depth prediction pipeline, including:
             - **depth_np** (`np.ndarray`) Predicted depth map, with depth values in the range of [0, 1]
-            - **depth_colored** (`None` or `PIL.Image.Image`) Colorized depth map, with the shape of [3, H, W] and
-              values in [0, 1]. None if `color_map` is `None`
+            - **depth_colored** (`PIL.Image.Image`) Colorized depth map, with the shape of [3, H, W] and values in [0, 1], None if `color_map` is `None`
             - **uncertainty** (`None` or `np.ndarray`) Uncalibrated uncertainty (MAD, median absolute deviation)
               coming from ensembling. None if `ensemble_size = 1`
         """
```
```diff
@@ -158,13 +179,21 @@ def __call__(
         if not match_input_res:
             assert processing_res is not None, "Value error: `resize_output_back` is only valid with "
         assert processing_res >= 0
-        assert denoising_steps >= 1
         assert ensemble_size >= 1
 
+        # Check if denoising step is reasonable
+        self._check_inference_step(denoising_steps)
+
+        resample_method: Resampling = get_pil_resample_method(resample_method)
+
         # ----------------- Image Preprocess -----------------
         # Resize image
         if processing_res > 0:
-            input_image = self.resize_max_res(input_image, max_edge_resolution=processing_res)
+            input_image = self.resize_max_res(
+                input_image,
+                max_edge_resolution=processing_res,
+                resample_method=resample_method,
+            )
         # Convert the image to RGB, to 1.remove the alpha channel 2.convert B&W to 3-channel
         input_image = input_image.convert("RGB")
         image = np.asarray(input_image)
```
```diff
@@ -203,9 +232,10 @@ def __call__(
                 rgb_in=batched_img,
                 num_inference_steps=denoising_steps,
                 show_pbar=show_progress_bar,
+                seed=seed,
             )
-            depth_pred_ls.append(depth_pred_raw.detach().clone())
-        depth_preds = torch.concat(depth_pred_ls, axis=0).squeeze()
+            depth_pred_ls.append(depth_pred_raw.detach())
+        depth_preds = torch.concat(depth_pred_ls, dim=0).squeeze()
         torch.cuda.empty_cache()  # clear vram cache for ensembling
 
         # ----------------- Test-time ensembling -----------------
@@ -227,7 +257,7 @@ def __call__(
         # Resize back to original resolution
         if match_input_res:
             pred_img = Image.fromarray(depth_pred)
-            pred_img = pred_img.resize(input_size)
+            pred_img = pred_img.resize(input_size, resample=resample_method)
             depth_pred = np.asarray(pred_img)
 
         # Clip output range
```
```diff
@@ -243,12 +273,32 @@ def __call__(
             depth_colored_img = Image.fromarray(depth_colored_hwc)
         else:
             depth_colored_img = None
+
         return MarigoldDepthOutput(
             depth_np=depth_pred,
             depth_colored=depth_colored_img,
             uncertainty=pred_uncert,
         )
 
+    def _check_inference_step(self, n_step: int):
+        """
+        Check if denoising step is reasonable
+        Args:
+            n_step (`int`): denoising steps
+        """
+        assert n_step >= 1
+
+        if isinstance(self.scheduler, DDIMScheduler):
+            if n_step < 10:
+                logging.warning(
+                    f"Too few denoising steps: {n_step}. Recommended to use the LCM checkpoint for few-step inference."
+                )
+        elif isinstance(self.scheduler, LCMScheduler):
+            if not 1 <= n_step <= 4:
+                logging.warning(f"Non-optimal setting of denoising steps: {n_step}. Recommended setting is 1-4 steps.")
+        else:
+            raise RuntimeError(f"Unsupported scheduler type: {type(self.scheduler)}")
+
     def _encode_empty_text(self):
         """
         Encode text embedding for empty prompt.
```
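The guard warns on suboptimal step counts and raises only for unsupported scheduler types. A hypothetical illustration, assuming a loaded `pipe` (note `_check_inference_step` is internal; `__call__` invokes it automatically):

```python
from diffusers import LCMScheduler

# Swap in an LCM scheduler, then probe the step check directly (illustrative only).
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
pipe._check_inference_step(4)   # silent: within the recommended 1-4 range
pipe._check_inference_step(10)  # logs "Non-optimal setting ..." but does not raise
```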
```diff
@@ -265,7 +315,13 @@ def _encode_empty_text(self):
         self.empty_text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype)
 
     @torch.no_grad()
-    def single_infer(self, rgb_in: torch.Tensor, num_inference_steps: int, show_pbar: bool) -> torch.Tensor:
+    def single_infer(
+        self,
+        rgb_in: torch.Tensor,
+        num_inference_steps: int,
+        seed: Union[int, None],
+        show_pbar: bool,
+    ) -> torch.Tensor:
         """
         Perform an individual depth prediction without ensembling.
```
```diff
@@ -286,10 +342,20 @@ def single_infer(self, rgb_in: torch.Tensor, num_inference_steps: int, show_pbar
         timesteps = self.scheduler.timesteps  # [T]
 
         # Encode image
-        rgb_latent = self._encode_rgb(rgb_in)
+        rgb_latent = self.encode_rgb(rgb_in)
 
         # Initial depth map (noise)
-        depth_latent = torch.randn(rgb_latent.shape, device=device, dtype=self.dtype)  # [B, 4, h, w]
+        if seed is None:
+            rand_num_generator = None
+        else:
+            rand_num_generator = torch.Generator(device=device)
+            rand_num_generator.manual_seed(seed)
+        depth_latent = torch.randn(
+            rgb_latent.shape,
+            device=device,
+            dtype=self.dtype,
+            generator=rand_num_generator,
+        )  # [B, 4, h, w]
 
         # Batched empty text embedding
         if self.empty_text_embed is None:
```
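The per-call `torch.Generator` keeps seeding local to the pipeline instead of mutating torch's global RNG state. A minimal sketch of why this makes the initial latent repeatable (shape and device are illustrative):

```python
import torch

def make_latent(seed: int) -> torch.Tensor:
    # Mirrors the seeding pattern above, on CPU for portability.
    generator = torch.Generator(device="cpu")
    generator.manual_seed(seed)
    return torch.randn((1, 4, 96, 72), generator=generator)  # [B, 4, h, w]

# Same seed -> identical initial depth latent -> reproducible prediction
# (combined with batch_size=1, as the README comment suggests).
assert torch.equal(make_latent(2024), make_latent(2024))
```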
```diff
@@ -314,9 +380,9 @@ def single_infer(self, rgb_in: torch.Tensor, num_inference_steps: int, show_pbar
             noise_pred = self.unet(unet_input, t, encoder_hidden_states=batch_empty_text_embed).sample  # [B, 4, h, w]
 
             # compute the previous noisy sample x_t -> x_t-1
-            depth_latent = self.scheduler.step(noise_pred, t, depth_latent).prev_sample
-        torch.cuda.empty_cache()
-        depth = self._decode_depth(depth_latent)
+            depth_latent = self.scheduler.step(noise_pred, t, depth_latent, generator=rand_num_generator).prev_sample
+
+        depth = self.decode_depth(depth_latent)
 
         # clip prediction
         depth = torch.clip(depth, -1.0, 1.0)
@@ -325,7 +391,7 @@ def single_infer(self, rgb_in: torch.Tensor, num_inference_steps: int, show_pbar
 
         return depth
 
-    def _encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor:
+    def encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor:
         """
         Encode RGB image into latent.
@@ -344,7 +410,7 @@ def _encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor:
         rgb_latent = mean * self.rgb_latent_scale_factor
         return rgb_latent
 
-    def _decode_depth(self, depth_latent: torch.Tensor) -> torch.Tensor:
+    def decode_depth(self, depth_latent: torch.Tensor) -> torch.Tensor:
         """
         Decode depth latent into depth map.
@@ -365,7 +431,7 @@ def _decode_depth(self, depth_latent: torch.Tensor) -> torch.Tensor:
         return depth_mean
 
     @staticmethod
-    def resize_max_res(img: Image.Image, max_edge_resolution: int) -> Image.Image:
+    def resize_max_res(img: Image.Image, max_edge_resolution: int, resample_method=Resampling.BILINEAR) -> Image.Image:
         """
         Resize image to limit maximum edge length while keeping aspect ratio.
@@ -374,6 +440,8 @@ def resize_max_res(img: Image.Image, max_edge_resolution: int) -> Image.Image:
                 Image to be resized.
             max_edge_resolution (`int`):
                 Maximum edge length (pixel).
+            resample_method (`PIL.Image.Resampling`):
+                Resampling method used to resize images.
 
         Returns:
             `Image.Image`: Resized image.
@@ -384,7 +452,7 @@ def resize_max_res(img: Image.Image, max_edge_resolution: int) -> Image.Image:
         new_width = int(original_width * downscale_factor)
         new_height = int(original_height * downscale_factor)
 
-        resized_img = img.resize((new_width, new_height))
+        resized_img = img.resize((new_width, new_height), resample=resample_method)
         return resized_img
 
     @staticmethod
```
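A small check of `resize_max_res` semantics with the new `resample_method` argument (input size is illustrative):

```python
from PIL import Image
from PIL.Image import Resampling

img = Image.new("RGB", (1024, 768))  # hypothetical input
out = MarigoldPipeline.resize_max_res(img, max_edge_resolution=512, resample_method=Resampling.BICUBIC)
assert out.size == (512, 384)  # longest edge capped, aspect ratio preserved
```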
