-
Notifications
You must be signed in to change notification settings - Fork 357
Expand file tree
/
Copy pathcotracker3.py
More file actions
147 lines (119 loc) · 4.11 KB
/
cotracker3.py
File metadata and controls
147 lines (119 loc) · 4.11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import sys
import cv2
import time
import numpy as np
import ailia
import onnxruntime as ort
from vis import Visualizer
# import original modules
sys.path.append('../../util')
from arg_utils import get_base_parser, update_parser # noqa: E402
from model_utils import check_and_download_models # noqa: E402
# logger
from logging import getLogger # noqa: E402
logger = getLogger(__name__)
# ======================
# Parameters
# ======================
# Default input/output paths used when the CLI does not override them.
VIDEO_PATH = 'input.mp4'
SAVE_PATH = 'output.mp4'
# ======================
# Argument Parser Config
# ======================
parser = get_base_parser(
    'CoTracker3: Simpler and Better Point Tracking by Pseudo-Labelling Real Videos',
    VIDEO_PATH,
    SAVE_PATH,
)
# Size of the regular point grid sampled on the query frame.
parser.add_argument("--grid_size", type=int, default=10, help="Regular grid size")
parser.add_argument(
    "--grid_query_frame",
    type=int,
    default=0,
    help="Compute dense and grid tracks starting from this frame",
)
# NOTE(review): --backward_tracking is parsed but never read in this file —
# confirm whether it is consumed elsewhere (e.g. baked into the ONNX export).
parser.add_argument(
    "--backward_tracking",
    action="store_true",
    help="Compute tracks in both directions, not only forward",
)
# When set, inference runs through onnxruntime instead of the ailia SDK.
parser.add_argument('--onnx', action='store_true', help='execute onnxruntime version.')
args = update_parser(parser)
# ==========================
# MODEL AND OTHER PARAMETERS
# ==========================
WEIGHT_PATH = 'cotracker3.onnx'
MODEL_PATH = 'cotracker3.onnx.prototxt'
REMOTE_PATH = 'https://storage.googleapis.com/ailia-models/cotracker3/'
def read_video_from_path(path):
    """Read every frame of a video file into one numpy array.

    Parameters
    ----------
    path : str
        Path to the input video file.

    Returns
    -------
    np.ndarray or None
        Array of shape (num_frames, height, width, channels) in OpenCV's
        native BGR channel order, or None when the file cannot be opened
        or yields no frames.
    """
    cap = cv2.VideoCapture(path)
    # cv2.VideoCapture does not raise on a missing/corrupt file — it returns
    # an unopened capture object — so an explicit isOpened() check is needed
    # (the original try/except around the constructor could never fire).
    if not cap.isOpened():
        logger.error("Error opening video file: %s", path)
        return None
    frames = []
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        frames.append(frame)
    cap.release()
    if not frames:
        # np.stack raises ValueError on an empty list; report and bail instead.
        logger.error("No frames could be read from: %s", path)
        return None
    return np.stack(frames)
def compute(net, video):
    """Run CoTracker3 inference on a preprocessed video tensor.

    Parameters
    ----------
    net : ailia.Net or onnxruntime.InferenceSession
        Loaded inference backend; which one is selected by the --onnx flag.
    video : np.ndarray
        Float32 video tensor shaped (1, num_frames, channels, height, width).

    Returns
    -------
    list
        Raw model outputs: predicted tracks followed by per-point visibility.
    """
    grid_size = np.array(args.grid_size, dtype=np.int64)
    query_frame = np.array(args.grid_query_frame, dtype=np.int64)
    if args.onnx:
        # onnxruntime expects a feed dict keyed by the model's input names.
        names = [inp.name for inp in net.get_inputs()]
        feed = {
            names[0]: video,
            names[1]: grid_size,
            names[2]: query_frame,
        }
        return net.run([], feed)
    # The ailia backend takes its inputs positionally as a tuple.
    return net.run((video, grid_size, query_frame))
# ======================
# Main functions
# ======================
def recognize_from_video():
    """Track points across each input video and save the visualized result.

    Initializes the inference backend (ailia SDK, or onnxruntime when
    --onnx is given), reads each video from args.input, runs CoTracker3 via
    compute(), and writes the tracked visualization to args.savepath.
    """
    # net initialize
    if not args.onnx:
        memory_mode = ailia.get_memory_mode(
            reduce_constant=True, ignore_input_with_initializer=True,
            reduce_interstage=False, reuse_interstage=True)
        net = ailia.Net(MODEL_PATH, WEIGHT_PATH, env_id=args.env_id,
                        memory_mode=memory_mode)
    else:
        net = ort.InferenceSession(WEIGHT_PATH)

    vis = Visualizer(pad_value=120, linewidth=3)
    for path in args.input:
        # load video
        video = read_video_from_path(path)
        if video is None:
            # Unreadable or empty file: skip it rather than crash on transpose.
            logger.error('Could not read video: %s', path)
            continue
        # (T, H, W, C) -> (1, T, C, H, W) float32, the layout compute() feeds
        # to the model. (A duplicate no-op transpose in the original whose
        # result was discarded has been removed.)
        video = np.transpose(video, (0, 3, 1, 2))[np.newaxis, ...].astype(np.float32)

        # calculate feature map
        logger.info('Start calculating feature map...')
        if args.benchmark:
            logger.info('BENCHMARK mode')
            for i in range(args.benchmark_count):
                start = int(round(time.time() * 1000))
                result = compute(net, video)
                end = int(round(time.time() * 1000))
                logger.info(f'\tailia processing time {end - start} ms')
        else:
            result = compute(net, video)
        pred_tracks = np.array(result[0])
        pred_visibility = np.array(result[1])

        # save a video with predicted tracks
        logger.info(f'saved at : {args.savepath}')
        vis.visualize(
            video,
            pred_tracks,
            pred_visibility,
            args.savepath
        )
    logger.info('Script finished successfully.')
def main():
    """Entry point: ensure model files are present, then run tracking."""
    # model files check and download
    check_and_download_models(WEIGHT_PATH, MODEL_PATH, REMOTE_PATH)
    recognize_from_video()


if __name__ == '__main__':
    main()