-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathMiDasDepth.py
More file actions
33 lines (28 loc) · 876 Bytes
/
MiDasDepth.py
File metadata and controls
33 lines (28 loc) · 876 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import cv2
import torch
import matplotlib.pyplot as plt
def getDepth(transform, frame, model):
    """Run a depth model on one image and return its prediction resized to the image.

    Parameters
    ----------
    transform : callable
        Preprocessing applied to the RGB image. It must return a tensor (with a
        ``.to`` method) that ``model`` accepts as a batch; presumably the MiDaS
        transform matching ``model`` -- confirm against the caller.
    frame : numpy.ndarray
        Image in OpenCV BGR channel order; it is converted to RGB before
        ``transform`` is applied.
    model : callable
        Depth model. The input batch is moved to the device of the model's
        first parameter so the two always match. If the model exposes no
        parameters, CUDA is used when available, otherwise CPU.

    Returns
    -------
    numpy.ndarray
        The model output bicubically resized to the frame's height and width.
        With a single-image batch this is 2-D ``(H, W)``. With a larger batch the
        batch axis is kept: ``(B, H, W)``.
    """
    # Convert OpenCV's BGR channel order to RGB before preprocessing.
    img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

    # Put the input on the same device as the model's weights. Hard-coding
    # 'cuda' whenever a GPU exists breaks callers whose model is on CPU (or on
    # a specific GPU).
    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        # Not an nn.Module, or a module without parameters: keep the old heuristic.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    imgbatch = transform(img).to(device)

    with torch.no_grad():
        prediction = model(imgbatch)
        # Upscale the prediction back to the frame size. interpolate needs an
        # explicit channel axis, so this assumes the model returns (B, h, w).
        prediction = torch.nn.functional.interpolate(
            prediction.unsqueeze(1),
            size=img.shape[:2],
            mode='bicubic',
            align_corners=False,
        )

    # Drop only the channel axis and a size-1 batch axis. A bare .squeeze()
    # would also collapse a 1-pixel-tall/wide frame's spatial axis.
    prediction = prediction.squeeze(1).squeeze(0)

    # Move back to host memory as a NumPy array.
    return prediction.cpu().numpy()