Skip to content

Commit 11e4768

Browse files
authored
Merge pull request #79 from laugh12321/dev
Refactor tensorrt_yolo.infer module and remove deprecated typing types (Tuple, List, Set, Dict)
2 parents 4885cc3 + bd6063e commit 11e4768

File tree

27 files changed

+335
-541
lines changed

27 files changed

+335
-541
lines changed

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
cmake_minimum_required(VERSION 3.15.0) # 设置CMake的最低版本要求
33
cmake_policy(SET CMP0091 NEW) # 允许在CMake 3.10+中自动设置项目名作为二进制目录名
44
cmake_policy(SET CMP0146 OLD) # 忽略对find_package的过时警告
5-
project(TensorRT-YOLO VERSION 5.1.0 LANGUAGES CXX CUDA) # 定义项目名称、版本和使用的编程语言(C++和CUDA)
5+
project(TensorRT-YOLO VERSION 5.1.1 LANGUAGES CXX CUDA) # 定义项目名称、版本和使用的编程语言(C++和CUDA)
66

77
# 设置 C++ 标准
88
set(CMAKE_CXX_STANDARD 17) # 设置C++标准为17

README.en.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -132,13 +132,13 @@ English | [简体中文](README.md)
132132
#### Python Inference Example
133133

134134
> [!NOTE]
135-
> `DeployDet`, `DeployOBB`, `DeploySeg`, and `DeployPose` correspond to detection (Detect), oriented bounding box (OBB), segmentation (Segment), and pose estimation (Pose) models, respectively.
135+
> `DeployDet`, `DeployOBB`, `DeploySeg`, `DeployPose` and `DeployCls` correspond to detection (Detect), oriented bounding box (OBB), segmentation (Segment), pose estimation (Pose) and image classification (Classify) models, respectively.
136136
>
137137
> For these models, the `CG` version utilizes CUDA Graph to further accelerate the inference process, but please note that this feature is limited to static models.
138138

139139
```python
140140
import cv2
141-
from tensorrt_yolo.infer import DeployDet, generate_labels_with_colors, visualize
141+
from tensorrt_yolo.infer import DeployDet, generate_labels, visualize
142142
143143
# Initialize the model
144144
model = DeployDet("yolo11n-with-plugin.engine")
@@ -148,15 +148,15 @@ im = cv2.imread("test_image.jpg")
148148
result = model.predict(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))
149149
print(f"==> detect result: {result}")
150150
# Visualization
151-
labels = generate_labels_with_colors("labels.txt")
151+
labels = generate_labels("labels.txt")
152152
vis_im = visualize(im, result, labels)
153153
cv2.imwrite("vis_image.jpg", vis_im)
154154
```
155155

156156
### C++ SDK Quick Start<div id="quick-start-cpp"></div>
157157

158158
> [!NOTE]
159-
> `DeployDet`, `DeployOBB`, `DeploySeg`, and `DeployPose` correspond to detection (Detect), oriented bounding box (OBB), segmentation (Segment), and pose estimation (Pose) models, respectively.
159+
> `DeployDet`, `DeployOBB`, `DeploySeg`, `DeployPose` and `DeployCls` correspond to detection (Detect), oriented bounding box (OBB), segmentation (Segment), pose estimation (Pose) and image classification (Classify) models, respectively.
160160
>
161161
> For these models, the `CG` version utilizes CUDA Graph to further accelerate the inference process, but please note that this feature is limited to static models.
162162

README.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -131,13 +131,13 @@
131131
#### Python 推理示例
132132

133133
> [!NOTE]
134-
> `DeployDet``DeployOBB``DeploySeg``DeployPose` 分别对应于检测(Detect)、方向边界框(OBB)、分割(Segment)和姿态估计(Pose)模型。
134+
> `DeployDet``DeployOBB``DeploySeg``DeployPose``DeployCls` 分别对应于检测(Detect)、方向边界框(OBB)、分割(Segment)、姿态估计(Pose)和图像分类(Classify)模型。
135135
>
136136
> 对于这些模型,`CG` 版本利用 CUDA Graph 来进一步加速推理过程,但请注意,这一功能仅限于静态模型。
137137

138138
```python
139139
import cv2
140-
from tensorrt_yolo.infer import DeployDet, generate_labels_with_colors, visualize
140+
from tensorrt_yolo.infer import DeployDet, generate_labels, visualize
141141
142142
# 初始化模型
143143
model = DeployDet("yolo11n-with-plugin.engine")
@@ -147,15 +147,15 @@ im = cv2.imread("test_image.jpg")
147147
result = model.predict(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))
148148
print(f"==> detect result: {result}")
149149
# 可视化
150-
labels = generate_labels_with_colors("labels.txt")
150+
labels = generate_labels("labels.txt")
151151
vis_im = visualize(im, result, labels)
152152
cv2.imwrite("vis_image.jpg", vis_im)
153153
```
154154

155155
### C++ SDK快速开始<div id="quick-start-cpp"></div>
156156

157157
> [!NOTE]
158-
> `DeployDet``DeployOBB``DeploySeg``DeployPose` 分别对应于检测(Detect)、方向边界框(OBB)、分割(Segment)和姿态估计(Pose)模型。
158+
> `DeployDet``DeployOBB``DeploySeg``DeployPose``DeployCls` 分别对应于检测(Detect)、方向边界框(OBB)、分割(Segment)、姿态估计(Pose)和图像分类(Classify)模型。
159159
>
160160
> 对于这些模型,`CG` 版本利用 CUDA Graph 来进一步加速推理过程,但请注意,这一功能仅限于静态模型。
161161

assets/detect.jpg

-1.3 KB
Loading

assets/obb.png

6.56 KB
Loading

assets/pose.jpg

-878 Bytes
Loading

assets/segment.jpg

183 Bytes
Loading

examples/classify/classify.cpp

Lines changed: 12 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
#include <iostream>
44
#include <memory>
55
#include <opencv2/opencv.hpp>
6-
#include <random>
76

87
#include "deploy/utils/utils.hpp"
98
#include "deploy/vision/inference.hpp"
@@ -32,39 +31,31 @@ void createOutputDirectory(const std::string& outputPath) {
3231
}
3332
}
3433

35-
// Generate label and color pairs
36-
std::vector<std::pair<std::string, cv::Scalar>> generateLabelColorPairs(const std::string& labelFile) {
37-
std::ifstream file(labelFile);
38-
std::vector<std::pair<std::string, cv::Scalar>> labelColorPairs;
34+
// Generate label
35+
std::vector<std::string> generateLabels(const std::string& labelFile) {
36+
std::ifstream file(labelFile);
37+
std::vector<std::string> labels;
3938
if (!file.is_open()) {
4039
throw std::runtime_error("Failed to open labels file: " + labelFile);
4140
}
4241

43-
auto generateRandomColor = []() {
44-
std::random_device rd;
45-
std::mt19937 gen(rd());
46-
std::uniform_int_distribution<int> dis(0, 255);
47-
return cv::Scalar(dis(gen), dis(gen), dis(gen));
48-
};
49-
5042
std::string label;
5143
while (std::getline(file, label)) {
52-
labelColorPairs.emplace_back(label, generateRandomColor());
44+
labels.emplace_back(label);
5345
}
54-
return labelColorPairs;
46+
return labels;
5547
}
5648

5749
// Visualize inference results
58-
void visualize(cv::Mat& image, deploy::ClsResult& result, std::vector<std::pair<std::string, cv::Scalar>>& labelColorPairs) {
50+
void visualize(cv::Mat& image, deploy::ClsResult& result, std::vector<std::string>& labels) {
5951
for (size_t i = 0; i < result.num; ++i) {
6052
int cls = result.classes[i];
6153
float score = result.scores[i];
62-
auto& label = labelColorPairs[cls].first;
63-
auto& color = labelColorPairs[cls].second;
64-
std::string labelText = label + " " + cv::format("%.2f", score);
54+
auto& label = labels[cls];
55+
std::string labelText = label + " " + cv::format("%.3f", score);
6556

6657
// Draw rectangle and label
67-
cv::putText(image, labelText, cv::Point(5, 32 + i * 32), cv::FONT_HERSHEY_SIMPLEX, 0.6, color, 1);
58+
cv::putText(image, labelText, cv::Point(5, 32 + i * 32), cv::FONT_HERSHEY_SIMPLEX, 0.6, cv::Scalar(251, 81, 163), 1);
6859
}
6960
}
7061

@@ -119,7 +110,7 @@ int main(int argc, char** argv) {
119110
throw std::runtime_error("Input path does not exist or is not a regular file/directory: " + inputPath);
120111
}
121112

122-
std::vector<std::pair<std::string, cv::Scalar>> labels;
113+
std::vector<std::string> labels;
123114
if (!outputPath.empty()) {
124115
if (labelPath.empty()) {
125116
throw std::runtime_error("Please provide a labels file using -l or --labels.");
@@ -128,7 +119,7 @@ int main(int argc, char** argv) {
128119
throw std::runtime_error("Label path does not exist: " + labelPath);
129120
}
130121

131-
labels = generateLabelColorPairs(labelPath);
122+
labels = generateLabels(labelPath);
132123
createOutputDirectory(outputPath);
133124
}
134125

examples/classify/classify.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from loguru import logger
3131
from rich.progress import track
3232

33-
from tensorrt_yolo.infer import CpuTimer, DeployCGCls, DeployCls, GpuTimer, generate_labels_with_colors, image_batches, visualize
33+
from tensorrt_yolo.infer import CpuTimer, DeployCGCls, DeployCls, GpuTimer, generate_labels, image_batches, visualize
3434

3535

3636
def main():
@@ -54,7 +54,7 @@ def main():
5454
if args.output:
5555
output_dir = Path(args.output)
5656
output_dir.mkdir(parents=True, exist_ok=True)
57-
args.labels = generate_labels_with_colors(args.labels)
57+
args.labels = generate_labels(args.labels)
5858

5959
model = DeployCGCls(args.engine) if args.cudaGraph else DeployCls(args.engine)
6060

examples/detect/detect.cpp

Lines changed: 14 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
#include <iostream>
44
#include <memory>
55
#include <opencv2/opencv.hpp>
6-
#include <random>
76

87
#include "deploy/utils/utils.hpp"
98
#include "deploy/vision/inference.hpp"
@@ -32,44 +31,36 @@ void createOutputDirectory(const std::string& outputPath) {
3231
}
3332
}
3433

35-
// Generate label and color pairs
36-
std::vector<std::pair<std::string, cv::Scalar>> generateLabelColorPairs(const std::string& labelFile) {
37-
std::ifstream file(labelFile);
38-
std::vector<std::pair<std::string, cv::Scalar>> labelColorPairs;
34+
// Generate label
35+
std::vector<std::string> generateLabels(const std::string& labelFile) {
36+
std::ifstream file(labelFile);
37+
std::vector<std::string> labels;
3938
if (!file.is_open()) {
4039
throw std::runtime_error("Failed to open labels file: " + labelFile);
4140
}
4241

43-
auto generateRandomColor = []() {
44-
std::random_device rd;
45-
std::mt19937 gen(rd());
46-
std::uniform_int_distribution<int> dis(0, 255);
47-
return cv::Scalar(dis(gen), dis(gen), dis(gen));
48-
};
49-
5042
std::string label;
5143
while (std::getline(file, label)) {
52-
labelColorPairs.emplace_back(label, generateRandomColor());
44+
labels.emplace_back(label);
5345
}
54-
return labelColorPairs;
46+
return labels;
5547
}
5648

5749
// Visualize inference results
58-
void visualize(cv::Mat& image, deploy::DetResult& result, std::vector<std::pair<std::string, cv::Scalar>>& labelColorPairs) {
50+
void visualize(cv::Mat& image, deploy::DetResult& result, std::vector<std::string>& labels) {
5951
for (size_t i = 0; i < result.num; ++i) {
6052
auto& box = result.boxes[i];
6153
int cls = result.classes[i];
6254
float score = result.scores[i];
63-
auto& label = labelColorPairs[cls].first;
64-
auto& color = labelColorPairs[cls].second;
65-
std::string labelText = label + " " + cv::format("%.2f", score);
55+
auto& label = labels[cls];
56+
std::string labelText = label + " " + cv::format("%.3f", score);
6657

6758
// Draw rectangle and label
6859
int baseLine;
6960
cv::Size labelSize = cv::getTextSize(labelText, cv::FONT_HERSHEY_SIMPLEX, 0.6, 1, &baseLine);
70-
cv::rectangle(image, cv::Point(box.left, box.top), cv::Point(box.right, box.bottom), color, 2, cv::LINE_AA);
71-
cv::rectangle(image, cv::Point(box.left, box.top - labelSize.height), cv::Point(box.left + labelSize.width, box.top), color, -1);
72-
cv::putText(image, labelText, cv::Point(box.left, box.top), cv::FONT_HERSHEY_SIMPLEX, 0.6, cv::Scalar(255, 255, 255), 1);
61+
cv::rectangle(image, cv::Point(box.left, box.top), cv::Point(box.right, box.bottom), cv::Scalar(251, 81, 163), 2, cv::LINE_AA);
62+
cv::rectangle(image, cv::Point(box.left, box.top - labelSize.height), cv::Point(box.left + labelSize.width, box.top), cv::Scalar(125, 40, 81), -1);
63+
cv::putText(image, labelText, cv::Point(box.left, box.top), cv::FONT_HERSHEY_SIMPLEX, 0.6, cv::Scalar(253, 168, 208), 1);
7364
}
7465
}
7566

@@ -124,7 +115,7 @@ int main(int argc, char** argv) {
124115
throw std::runtime_error("Input path does not exist or is not a regular file/directory: " + inputPath);
125116
}
126117

127-
std::vector<std::pair<std::string, cv::Scalar>> labels;
118+
std::vector<std::string> labels;
128119
if (!outputPath.empty()) {
129120
if (labelPath.empty()) {
130121
throw std::runtime_error("Please provide a labels file using -l or --labels.");
@@ -133,7 +124,7 @@ int main(int argc, char** argv) {
133124
throw std::runtime_error("Label path does not exist: " + labelPath);
134125
}
135126

136-
labels = generateLabelColorPairs(labelPath);
127+
labels = generateLabels(labelPath);
137128
createOutputDirectory(outputPath);
138129
}
139130

0 commit comments

Comments
 (0)