-
Notifications
You must be signed in to change notification settings - Fork 357
Expand file tree
/
Copy pathzoe_depth_util.py
More file actions
58 lines (47 loc) · 1.5 KB
/
zoe_depth_util.py
File metadata and controls
58 lines (47 loc) · 1.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from typing import Optional, Tuple

import cv2
import matplotlib
import numpy as np
from einops import rearrange
from PIL import Image
def get_params(arch):
    """Return the ``(weight_path, model_path)`` file names for *arch*.

    The weight file is ``<arch>.onnx``; the model file is the same name
    with an extra ``.prototxt`` suffix.
    """
    onnx_name = f"{arch}.onnx"
    return onnx_name, onnx_name + ".prototxt"
def preprocess(img_orig, input_size) -> Tuple[np.ndarray, np.ndarray]:
    """Resize and normalize an image into two NCHW float32 network inputs.

    Args:
        img_orig: HWC image array (assumed 3 channels -- TODO confirm).
        input_size: Target size passed straight to ``cv2.resize``, which
            interprets it as ``(width, height)``.

    Returns:
        A pair ``(img, img_reversed)``, each of shape ``(1, C, H, W)`` with
        values scaled into ``[0, 1]``. ``img_reversed`` is the same image
        with the last HWC axis reversed before the layout change.
    """
    scaled = cv2.resize(img_orig, input_size).astype(np.float32) / 255.0
    # NOTE(review): ``[..., ::-1]`` reverses the channel axis (BGR <-> RGB),
    # not the width axis. If a horizontal-flip test-time augmentation was
    # intended, this would need to be ``[:, ::-1]`` -- confirm with caller.
    reversed_channels = scaled[..., ::-1]
    layout = "h w c -> 1 c h w"
    return rearrange(scaled, layout), rearrange(reversed_channels, layout)
def postprocess(
    pred,
    original_width: int,
    original_height: int,
    vmin: Optional[float] = 0,
    vmax: Optional[float] = 10,
    cmap: str = "magma_r",
) -> np.ndarray:
    """Colorize a depth prediction and resize it to the original image size.

    Args:
        pred: Depth prediction array. Pixels equal to -99 are treated as
            invalid and painted opaque mid-gray. Assumed to be 2-D (H, W),
            since ``cv2.resize`` below needs an image-shaped array -- TODO
            confirm with caller.
        original_width: Width of the returned image, in pixels.
        original_height: Height of the returned image, in pixels.
        vmin: Depth mapped to the low end of the colormap. If None, the 2nd
            percentile of the valid pixels is used.
        vmax: Depth mapped to the high end of the colormap. If None, the
            85th percentile of the valid pixels is used.
        cmap: Name of a registered matplotlib colormap.

    Returns:
        uint8 RGBA image resized to ``(original_width, original_height)``.
        The caller's ``pred`` is not modified.
    """
    pred = np.asarray(pred)
    invalid_mask = pred == -99
    valid = pred[np.logical_not(invalid_mask)]

    # With no valid pixels the whole image ends up gray anyway; avoid calling
    # np.percentile on an empty array.
    if vmin is None:
        vmin = np.percentile(valid, 2) if valid.size else 0.0
    if vmax is None:
        vmax = np.percentile(valid, 85) if valid.size else 0.0

    # Both branches build a new float array, so the caller's data is untouched.
    if vmin != vmax:
        pred = (pred - vmin) / (vmax - vmin)
    else:
        # Constant range: avoid division by zero and map every valid pixel
        # to the low end of the colormap.
        pred = pred * 0.0
    pred[invalid_mask] = np.nan

    # ``matplotlib.colormaps`` is the registry on newer matplotlib;
    # ``cm.get_cmap`` is the fallback for older releases.
    if hasattr(matplotlib, "colormaps"):
        cmapper = matplotlib.colormaps[cmap]
    else:
        cmapper = matplotlib.cm.get_cmap(cmap)
    img = cmapper(pred, bytes=True)

    # img is uint8, so alpha must be 255 (opaque); 256 would overflow.
    img[invalid_mask] = (128, 128, 128, 255)
    # cv2.resize takes dsize as (width, height).
    return cv2.resize(img, (original_width, original_height))
def save(
    pred: np.ndarray,
    output_filename: str,
):
    """Write the image array *pred* to *output_filename* using Pillow.

    Pillow picks the file format from the filename's extension.
    """
    image = Image.fromarray(pred)
    image.save(output_filename)