#include <iostream>
#include <opencv2/opencv.hpp>

using namespace std;
using namespace cv;
using namespace dnn;

// Result of a single tracking update.
struct TrackingResult
{
    bool isLocated; // true if the tracker located the target in this frame
    Rect bbox;      // bounding box of the tracked target
    float score;    // tracking confidence reported by the model
};

// Thin wrapper around cv::TrackerVit: creates the model once, then exposes
// init() for the first frame/ROI and infer() for each subsequent frame.
class VitTrack
{
public:
    VitTrack(const string& model_path, int backend_id = 0, int target_id = 0)
    {
        params.net = model_path;
        params.backend = backend_id;
        params.target = target_id;
        model = TrackerVit::create(params);
    }

    // Initialize the tracker with the first frame and the selected ROI.
    void init(const Mat& image, const Rect& roi)
    {
        model->init(image, roi);
    }

    // Run the tracker on the next frame and return location flag, box, and score.
    TrackingResult infer(const Mat& image)
    {
        TrackingResult result;
        result.isLocated = model->update(image, result.bbox);
        result.score = model->getTrackingScore();
        return result;
    }

private:
    TrackerVit::Params params;
    Ptr<TrackerVit> model;
};

// Draw the tracking result on a copy of the frame: the FPS readout (if given),
// the bounding box with its score when the target is located, or a centered
// "Target lost!" message otherwise.
Mat visualize(const Mat& image, const Rect& bbox, float score, bool isLocated, double fps = -1.0,
              const Scalar& box_color = Scalar(0, 255, 0), const Scalar& text_color = Scalar(0, 255, 0),
              double fontScale = 1.0, int fontSize = 1)
{
    Mat output = image.clone();
    int h = output.rows;
    int w = output.cols;

    if (fps >= 0)
    {
        putText(output, "FPS: " + to_string(fps), Point(0, 30), FONT_HERSHEY_DUPLEX, fontScale, text_color, fontSize);
    }

    if (isLocated && score >= 0.3)
    {
        rectangle(output, bbox, box_color, 2);
        putText(output, format("%.2f", score), Point(bbox.x, bbox.y + 25),
                FONT_HERSHEY_DUPLEX, fontScale, text_color, fontSize);
    }
    else
    {
        Size text_size = getTextSize("Target lost!", FONT_HERSHEY_DUPLEX, fontScale, fontSize, nullptr);
        int text_x = (w - text_size.width) / 2;
        int text_y = (h - text_size.height) / 2;
        putText(output, "Target lost!", Point(text_x, text_y), FONT_HERSHEY_DUPLEX, fontScale, Scalar(0, 0, 255), fontSize);
    }

    return output;
}

int main(int argc, char** argv)
{
    CommandLineParser parser(argc, argv,
        "{help h | | Print help message.}"
        "{input i | |Set path to the input video. Omit to use the default camera.}"
        "{model_path |object_tracking_vittrack_2023sep.onnx |Set path to the ONNX model.}"
        "{backend_target bt |0 |Choose backend-target pair: 0 - OpenCV implementation + CPU, 1 - CUDA + GPU (CUDA), 2 - CUDA + GPU (CUDA FP16), 3 - TIM-VX + NPU, 4 - CANN + NPU}"
        "{save s |false |Specify to save results to a video file.}"
        "{vis v |true |Specify to open a new window to show results.}");
    if (parser.has("help"))
    {
        parser.printMessage();
        return 0;
    }

    string input = parser.get<string>("input");
    string model_path = parser.get<string>("model_path");
    int backend_target = parser.get<int>("backend_target");
    bool save = parser.get<bool>("save");
    bool vis = parser.get<bool>("vis");

    vector<vector<int>> backend_target_pairs =
    {
        {DNN_BACKEND_OPENCV, DNN_TARGET_CPU},
        {DNN_BACKEND_CUDA, DNN_TARGET_CUDA},
        {DNN_BACKEND_CUDA, DNN_TARGET_CUDA_FP16},
        {DNN_BACKEND_TIMVX, DNN_TARGET_NPU},
        {DNN_BACKEND_CANN, DNN_TARGET_NPU}
    };

    int backend_id = backend_target_pairs[backend_target][0];
    int target_id = backend_target_pairs[backend_target][1];

    // Create VitTrack tracker
    VitTrack tracker(model_path, backend_id, target_id);

    // Open video capture
    VideoCapture video;
    if (input.empty())
    {
        video.open(0); // Default camera
    }
    else
    {
        video.open(input);
    }

    if (!video.isOpened())
    {
        cerr << "Error: Could not open video source" << endl;
        return -1;
    }

    // Select an object
    Mat first_frame;
    video >> first_frame;

    if (first_frame.empty())
    {
        cerr << "No frames grabbed!" << endl;
        return -1;
    }

    Mat first_frame_copy = first_frame.clone();
    putText(first_frame_copy, "1. Drag a bounding box to track.", Point(0, 25), FONT_HERSHEY_SIMPLEX, 1, Scalar(0, 255, 0));
    putText(first_frame_copy, "2. Press ENTER to confirm", Point(0, 50), FONT_HERSHEY_SIMPLEX, 1, Scalar(0, 255, 0));
    Rect roi = selectROI("VitTrack Demo", first_frame_copy);

    if (roi.area() == 0)
    {
        cerr << "No ROI is selected! Exiting..." << endl;
        return -1;
    }
    else
    {
        cout << "Selected ROI: " << roi << endl;
    }

    // Create VideoWriter if save option is specified
    VideoWriter output_video;
    if (save)
    {
        Size frame_size = first_frame.size();
        output_video.open("output.mp4", VideoWriter::fourcc('m', 'p', '4', 'v'), video.get(CAP_PROP_FPS), frame_size);
        if (!output_video.isOpened())
        {
            cerr << "Error: Could not create output video stream" << endl;
            return -1;
        }
    }

    // Initialize tracker with ROI
    tracker.init(first_frame, roi);

    // Track frame by frame
    TickMeter tm;
    Mat frame;
    while (waitKey(1) < 0)
    {
        video >> frame;
        if (frame.empty())
        {
            cout << "End of video" << endl;
            break;
        }

        // Inference
        tm.start();
        TrackingResult result = tracker.infer(frame);
        tm.stop();

        // Visualize (visualize() draws on its own copy of the frame)
        Mat vis_frame = visualize(frame, result.bbox, result.score, result.isLocated, tm.getFPS());

        if (save)
        {
            output_video.write(vis_frame);
        }

        if (vis)
        {
            imshow("VitTrack Demo", vis_frame);
        }
        tm.reset();
    }

    if (save)
    {
        output_video.release();
    }

    video.release();
    destroyAllWindows();

    return 0;
}
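
// ---------------------------------------------------------------------------
// Build/run sketch (illustrative only): assumes an OpenCV build that provides
// cv::TrackerVit (4.9.0 or newer) with the dnn and video modules, plus
// pkg-config; the source/binary names here are placeholders, and the model
// file is the default expected by --model_path.
//
//   g++ -std=c++11 vittrack_demo.cpp -o vittrack_demo $(pkg-config --cflags --libs opencv4)
//   ./vittrack_demo --model_path=object_tracking_vittrack_2023sep.onnx --input=video.mp4
// ---------------------------------------------------------------------------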