18
18
# --------------------------------------------------------------------------
19
19
20
20
21
+ import logging
21
22
import math
22
23
from typing import Dict , Union
23
24
24
25
import matplotlib
25
26
import numpy as np
26
27
import torch
27
28
from PIL import Image
29
+ from PIL .Image import Resampling
28
30
from scipy .optimize import minimize
29
31
from torch .utils .data import DataLoader , TensorDataset
30
32
from tqdm .auto import tqdm
34
36
AutoencoderKL ,
35
37
DDIMScheduler ,
36
38
DiffusionPipeline ,
39
+ LCMScheduler ,
37
40
UNet2DConditionModel ,
38
41
)
39
42
from diffusers .utils import BaseOutput , check_min_version
40
43
41
44
42
45
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
43
- check_min_version ("0.28.0.dev0 " )
46
+ check_min_version ("0.25.0 " )
44
47
45
48
46
49
class MarigoldDepthOutput (BaseOutput ):
@@ -61,6 +64,19 @@ class MarigoldDepthOutput(BaseOutput):
61
64
uncertainty : Union [None , np .ndarray ]
62
65
63
66
67
+ def get_pil_resample_method (method_str : str ) -> Resampling :
68
+ resample_method_dic = {
69
+ "bilinear" : Resampling .BILINEAR ,
70
+ "bicubic" : Resampling .BICUBIC ,
71
+ "nearest" : Resampling .NEAREST ,
72
+ }
73
+ resample_method = resample_method_dic .get (method_str , None )
74
+ if resample_method is None :
75
+ raise ValueError (f"Unknown resampling method: { resample_method } " )
76
+ else :
77
+ return resample_method
78
+
79
+
64
80
class MarigoldPipeline (DiffusionPipeline ):
65
81
"""
66
82
Pipeline for monocular depth estimation using Marigold: https://marigoldmonodepth.github.io.
@@ -113,7 +129,9 @@ def __call__(
113
129
ensemble_size : int = 10 ,
114
130
processing_res : int = 768 ,
115
131
match_input_res : bool = True ,
132
+ resample_method : str = "bilinear" ,
116
133
batch_size : int = 0 ,
134
+ seed : Union [int , None ] = None ,
117
135
color_map : str = "Spectral" ,
118
136
show_progress_bar : bool = True ,
119
137
ensemble_kwargs : Dict = None ,
@@ -129,14 +147,18 @@ def __call__(
129
147
If set to 0: will not resize at all.
130
148
match_input_res (`bool`, *optional*, defaults to `True`):
131
149
Resize depth prediction to match input resolution.
132
- Only valid if `limit_input_res` is not None.
150
+ Only valid if `processing_res` > 0.
151
+ resample_method: (`str`, *optional*, defaults to `bilinear`):
152
+ Resampling method used to resize images and depth predictions. This can be one of `bilinear`, `bicubic` or `nearest`, defaults to: `bilinear`.
133
153
denoising_steps (`int`, *optional*, defaults to `10`):
134
154
Number of diffusion denoising steps (DDIM) during inference.
135
155
ensemble_size (`int`, *optional*, defaults to `10`):
136
156
Number of predictions to be ensembled.
137
157
batch_size (`int`, *optional*, defaults to `0`):
138
158
Inference batch size, no bigger than `num_ensemble`.
139
159
If set to 0, the script will automatically decide the proper batch size.
160
+ seed (`int`, *optional*, defaults to `None`)
161
+ Reproducibility seed.
140
162
show_progress_bar (`bool`, *optional*, defaults to `True`):
141
163
Display a progress bar of diffusion denoising.
142
164
color_map (`str`, *optional*, defaults to `"Spectral"`, pass `None` to skip colorized depth map generation):
@@ -146,8 +168,7 @@ def __call__(
146
168
Returns:
147
169
`MarigoldDepthOutput`: Output class for Marigold monocular depth prediction pipeline, including:
148
170
- **depth_np** (`np.ndarray`) Predicted depth map, with depth values in the range of [0, 1]
149
- - **depth_colored** (`None` or `PIL.Image.Image`) Colorized depth map, with the shape of [3, H, W] and
150
- values in [0, 1]. None if `color_map` is `None`
171
+ - **depth_colored** (`PIL.Image.Image`) Colorized depth map, with the shape of [3, H, W] and values in [0, 1], None if `color_map` is `None`
151
172
- **uncertainty** (`None` or `np.ndarray`) Uncalibrated uncertainty(MAD, median absolute deviation)
152
173
coming from ensembling. None if `ensemble_size = 1`
153
174
"""
@@ -158,13 +179,21 @@ def __call__(
158
179
if not match_input_res :
159
180
assert processing_res is not None , "Value error: `resize_output_back` is only valid with "
160
181
assert processing_res >= 0
161
- assert denoising_steps >= 1
162
182
assert ensemble_size >= 1
163
183
184
+ # Check if denoising step is reasonable
185
+ self ._check_inference_step (denoising_steps )
186
+
187
+ resample_method : Resampling = get_pil_resample_method (resample_method )
188
+
164
189
# ----------------- Image Preprocess -----------------
165
190
# Resize image
166
191
if processing_res > 0 :
167
- input_image = self .resize_max_res (input_image , max_edge_resolution = processing_res )
192
+ input_image = self .resize_max_res (
193
+ input_image ,
194
+ max_edge_resolution = processing_res ,
195
+ resample_method = resample_method ,
196
+ )
168
197
# Convert the image to RGB, to 1.remove the alpha channel 2.convert B&W to 3-channel
169
198
input_image = input_image .convert ("RGB" )
170
199
image = np .asarray (input_image )
@@ -203,9 +232,10 @@ def __call__(
203
232
rgb_in = batched_img ,
204
233
num_inference_steps = denoising_steps ,
205
234
show_pbar = show_progress_bar ,
235
+ seed = seed ,
206
236
)
207
- depth_pred_ls .append (depth_pred_raw .detach (). clone () )
208
- depth_preds = torch .concat (depth_pred_ls , axis = 0 ).squeeze ()
237
+ depth_pred_ls .append (depth_pred_raw .detach ())
238
+ depth_preds = torch .concat (depth_pred_ls , dim = 0 ).squeeze ()
209
239
torch .cuda .empty_cache () # clear vram cache for ensembling
210
240
211
241
# ----------------- Test-time ensembling -----------------
@@ -227,7 +257,7 @@ def __call__(
227
257
# Resize back to original resolution
228
258
if match_input_res :
229
259
pred_img = Image .fromarray (depth_pred )
230
- pred_img = pred_img .resize (input_size )
260
+ pred_img = pred_img .resize (input_size , resample = resample_method )
231
261
depth_pred = np .asarray (pred_img )
232
262
233
263
# Clip output range
@@ -243,12 +273,32 @@ def __call__(
243
273
depth_colored_img = Image .fromarray (depth_colored_hwc )
244
274
else :
245
275
depth_colored_img = None
276
+
246
277
return MarigoldDepthOutput (
247
278
depth_np = depth_pred ,
248
279
depth_colored = depth_colored_img ,
249
280
uncertainty = pred_uncert ,
250
281
)
251
282
283
+ def _check_inference_step (self , n_step : int ):
284
+ """
285
+ Check if denoising step is reasonable
286
+ Args:
287
+ n_step (`int`): denoising steps
288
+ """
289
+ assert n_step >= 1
290
+
291
+ if isinstance (self .scheduler , DDIMScheduler ):
292
+ if n_step < 10 :
293
+ logging .warning (
294
+ f"Too few denoising steps: { n_step } . Recommended to use the LCM checkpoint for few-step inference."
295
+ )
296
+ elif isinstance (self .scheduler , LCMScheduler ):
297
+ if not 1 <= n_step <= 4 :
298
+ logging .warning (f"Non-optimal setting of denoising steps: { n_step } . Recommended setting is 1-4 steps." )
299
+ else :
300
+ raise RuntimeError (f"Unsupported scheduler type: { type (self .scheduler )} " )
301
+
252
302
def _encode_empty_text (self ):
253
303
"""
254
304
Encode text embedding for empty prompt.
@@ -265,7 +315,13 @@ def _encode_empty_text(self):
265
315
self .empty_text_embed = self .text_encoder (text_input_ids )[0 ].to (self .dtype )
266
316
267
317
@torch .no_grad ()
268
- def single_infer (self , rgb_in : torch .Tensor , num_inference_steps : int , show_pbar : bool ) -> torch .Tensor :
318
+ def single_infer (
319
+ self ,
320
+ rgb_in : torch .Tensor ,
321
+ num_inference_steps : int ,
322
+ seed : Union [int , None ],
323
+ show_pbar : bool ,
324
+ ) -> torch .Tensor :
269
325
"""
270
326
Perform an individual depth prediction without ensembling.
271
327
@@ -286,10 +342,20 @@ def single_infer(self, rgb_in: torch.Tensor, num_inference_steps: int, show_pbar
286
342
timesteps = self .scheduler .timesteps # [T]
287
343
288
344
# Encode image
289
- rgb_latent = self ._encode_rgb (rgb_in )
345
+ rgb_latent = self .encode_rgb (rgb_in )
290
346
291
347
# Initial depth map (noise)
292
- depth_latent = torch .randn (rgb_latent .shape , device = device , dtype = self .dtype ) # [B, 4, h, w]
348
+ if seed is None :
349
+ rand_num_generator = None
350
+ else :
351
+ rand_num_generator = torch .Generator (device = device )
352
+ rand_num_generator .manual_seed (seed )
353
+ depth_latent = torch .randn (
354
+ rgb_latent .shape ,
355
+ device = device ,
356
+ dtype = self .dtype ,
357
+ generator = rand_num_generator ,
358
+ ) # [B, 4, h, w]
293
359
294
360
# Batched empty text embedding
295
361
if self .empty_text_embed is None :
@@ -314,9 +380,9 @@ def single_infer(self, rgb_in: torch.Tensor, num_inference_steps: int, show_pbar
314
380
noise_pred = self .unet (unet_input , t , encoder_hidden_states = batch_empty_text_embed ).sample # [B, 4, h, w]
315
381
316
382
# compute the previous noisy sample x_t -> x_t-1
317
- depth_latent = self .scheduler .step (noise_pred , t , depth_latent ).prev_sample
318
- torch . cuda . empty_cache ()
319
- depth = self ._decode_depth (depth_latent )
383
+ depth_latent = self .scheduler .step (noise_pred , t , depth_latent , generator = rand_num_generator ).prev_sample
384
+
385
+ depth = self .decode_depth (depth_latent )
320
386
321
387
# clip prediction
322
388
depth = torch .clip (depth , - 1.0 , 1.0 )
@@ -325,7 +391,7 @@ def single_infer(self, rgb_in: torch.Tensor, num_inference_steps: int, show_pbar
325
391
326
392
return depth
327
393
328
- def _encode_rgb (self , rgb_in : torch .Tensor ) -> torch .Tensor :
394
+ def encode_rgb (self , rgb_in : torch .Tensor ) -> torch .Tensor :
329
395
"""
330
396
Encode RGB image into latent.
331
397
@@ -344,7 +410,7 @@ def _encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor:
344
410
rgb_latent = mean * self .rgb_latent_scale_factor
345
411
return rgb_latent
346
412
347
- def _decode_depth (self , depth_latent : torch .Tensor ) -> torch .Tensor :
413
+ def decode_depth (self , depth_latent : torch .Tensor ) -> torch .Tensor :
348
414
"""
349
415
Decode depth latent into depth map.
350
416
@@ -365,7 +431,7 @@ def _decode_depth(self, depth_latent: torch.Tensor) -> torch.Tensor:
365
431
return depth_mean
366
432
367
433
@staticmethod
368
- def resize_max_res (img : Image .Image , max_edge_resolution : int ) -> Image .Image :
434
+ def resize_max_res (img : Image .Image , max_edge_resolution : int , resample_method = Resampling . BILINEAR ) -> Image .Image :
369
435
"""
370
436
Resize image to limit maximum edge length while keeping aspect ratio.
371
437
@@ -374,6 +440,8 @@ def resize_max_res(img: Image.Image, max_edge_resolution: int) -> Image.Image:
374
440
Image to be resized.
375
441
max_edge_resolution (`int`):
376
442
Maximum edge length (pixel).
443
+ resample_method (`PIL.Image.Resampling`):
444
+ Resampling method used to resize images.
377
445
378
446
Returns:
379
447
`Image.Image`: Resized image.
@@ -384,7 +452,7 @@ def resize_max_res(img: Image.Image, max_edge_resolution: int) -> Image.Image:
384
452
new_width = int (original_width * downscale_factor )
385
453
new_height = int (original_height * downscale_factor )
386
454
387
- resized_img = img .resize ((new_width , new_height ))
455
+ resized_img = img .resize ((new_width , new_height ), resample = resample_method )
388
456
return resized_img
389
457
390
458
@staticmethod
0 commit comments