Skip to content

Commit 6eb0226

Browse files
Commit authored: Add FP16 support for real-esrgan in TensorRT10 (wang-xinyu#1697)
1 parent bbd19a0 commit 6eb0226

File tree

8 files changed

+409
-448
lines changed

8 files changed

+409
-448
lines changed

real-esrgan/CMakeLists.txt

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ include_directories(${PROJECT_SOURCE_DIR}/include)
2020
include_directories(/usr/local/cuda/include)
2121
link_directories(/usr/local/cuda/lib64)
2222
# tensorrt
23-
include_directories(/usr/include/x86_64-linux-gnu/)
24-
link_directories(/usr/lib/x86_64-linux-gnu/)
23+
include_directories(/usr/local/TensorRT-10.10.0.31/include)
24+
link_directories(/usr/local/TensorRT-10.10.0.31/lib)
2525

2626
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -Wall -Ofast -g -Wfatal-errors -D_MWAITXINTRIN_H_INCLUDED")
2727
cuda_add_library(myplugins SHARED preprocess.cu postprocess.cu)
@@ -40,5 +40,3 @@ target_link_libraries(real-esrgan ${OpenCV_LIBS})
4040
if(UNIX)
4141
add_definitions(-O2 -pthread)
4242
endif(UNIX)
43-
44-

real-esrgan/common.hpp

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010

1111
using namespace nvinfer1;
1212

13+
static const int PRECISION_MODE = 16; // fp32 : 32, fp16 : 16
14+
1315
// TensorRT weight files have a simple space delimited format:
1416
// [type] [size] <data x size in hex>
1517
std::map<std::string, Weights> loadWeights(const std::string file) {
@@ -32,14 +34,22 @@ std::map<std::string, Weights> loadWeights(const std::string file) {
3234
// Read name and type of blob
3335
std::string name;
3436
input >> name >> std::dec >> size;
35-
wt.type = DataType::kFLOAT;
3637

37-
// Load blob
38-
uint32_t* val = reinterpret_cast<uint32_t*>(malloc(sizeof(val) * size));
39-
for (uint32_t x = 0, y = size; x < y; ++x) {
40-
input >> std::hex >> val[x];
38+
if (PRECISION_MODE == 16) {
39+
wt.type = DataType::kHALF;
40+
uint16_t* val = reinterpret_cast<uint16_t*>(malloc(sizeof(val) * size));
41+
for (uint32_t x = 0, y = size; x < y; ++x) {
42+
input >> std::hex >> val[x];
43+
}
44+
wt.values = val;
45+
} else {
46+
wt.type = DataType::kFLOAT;
47+
uint32_t* val = reinterpret_cast<uint32_t*>(malloc(sizeof(val) * size));
48+
for (uint32_t x = 0, y = size; x < y; ++x) {
49+
input >> std::hex >> val[x];
50+
}
51+
wt.values = val;
4152
}
42-
wt.values = val;
4353

4454
wt.count = size;
4555
weightMap[name] = wt;

real-esrgan/gen_wts.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,23 @@
11
import argparse
22
import os
33
import struct
4+
import numpy as np
45
from basicsr.archs.rrdbnet_arch import RRDBNet
56
from realesrgan import RealESRGANer
67
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
78

9+
10+
def float32_to_float16_hex(value):
11+
f16 = np.float16(value)
12+
u16 = np.frombuffer(f16.tobytes(), dtype=np.uint16)[0]
13+
return format(u16, "04x")
14+
15+
816
def main():
917
"""Inference demo for Real-ESRGAN.
1018
"""
1119
parser = argparse.ArgumentParser()
12-
#parser.add_argument('-i', '--input', type=str, default='../TestData3', help='Input image or folder')
20+
# parser.add_argument('-i', '--input', type=str, default='../TestData3', help='Input image or folder')
1321
parser.add_argument('-i', '--input', type=str, default='inputs', help='Input image or folder')
1422
parser.add_argument(
1523
'-n',
@@ -84,9 +92,13 @@ def main():
8492
f.write("{} {}".format(k, len(vr)))
8593
for vv in vr:
8694
f.write(" ")
87-
f.write(struct.pack(">f", float(vv)).hex())
95+
if args.fp32:
96+
f.write(struct.pack(">f", float(vv)).hex())
97+
else:
98+
f.write(float32_to_float16_hex(float(vv)))
8899
f.write("\n")
89100
print('Completed real-esrgan.wts file!')
90101

102+
91103
if __name__ == '__main__':
92104
main()

real-esrgan/postprocess.cu

Lines changed: 44 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
1+
#include "cublas_v2.h"
12
#include "cuda_utils.h"
23

34
using namespace std;
45

56
// postprocess (NCHW->NHWC, RGB->BGR, *255, ROUND, uint8)
6-
__global__ void postprocess_kernel(uint8_t* output, float* input,
7-
const int batchSize, const int height, const int width, const int channel,
8-
const int thread_count)
9-
{
7+
template <typename T>
8+
__global__ void postprocess_kernel(uint8_t* output, const T* input, const int batchSize, const int height,
9+
const int width, const int channel, const int thread_count) {
1010
int index = threadIdx.x + blockIdx.x * blockDim.x;
11-
if (index >= thread_count) return;
11+
if (index >= thread_count)
12+
return;
1213

1314
const int c_idx = index % channel;
1415
int idx = index / channel;
@@ -17,38 +18,57 @@ __global__ void postprocess_kernel(uint8_t* output, float* input,
1718
const int h_idx = idx % height;
1819
const int b_idx = idx / height;
1920

20-
int g_idx = b_idx * height * width * channel + (2 - c_idx)* height * width + h_idx * width + w_idx;
21-
float tt = input[g_idx] * 255.f;
21+
int g_idx = b_idx * height * width * channel + (2 - c_idx) * height * width + h_idx * width + w_idx;
22+
float val = (float)input[g_idx];
23+
float tt = val * 255.f;
2224
if (tt > 255)
2325
tt = 255;
24-
output[index] = tt;
26+
if (tt < 0)
27+
tt = 0;
28+
output[index] = (uint8_t)tt;
2529
}
2630

27-
void postprocess(uint8_t* output, float*input, int batchSize, int height, int width, int channel, cudaStream_t stream)
28-
{
31+
template __global__ void postprocess_kernel<float>(uint8_t* output, const float* input, const int batchSize,
32+
const int height, const int width, const int channel,
33+
const int thread_count);
34+
template __global__ void postprocess_kernel<half>(uint8_t* output, const half* input, const int batchSize,
35+
const int height, const int width, const int channel,
36+
const int thread_count);
37+
38+
template <typename T>
39+
void postprocess(uint8_t* output, const T* input, int batchSize, int height, int width, int channel,
40+
cudaStream_t stream) {
2941
int thread_count = batchSize * height * width * channel;
3042
int block = 512;
3143
int grid = (thread_count - 1) / block + 1;
3244

33-
postprocess_kernel << <grid, block, 0, stream >> > (output, input, batchSize, height, width, channel, thread_count);
45+
postprocess_kernel<T><<<grid, block, 0, stream>>>(output, input, batchSize, height, width, channel, thread_count);
3446
}
3547

48+
template void postprocess<float>(uint8_t* output, const float* input, int batchSize, int height, int width, int channel,
49+
cudaStream_t stream);
50+
template void postprocess<half>(uint8_t* output, const half* input, int batchSize, int height, int width, int channel,
51+
cudaStream_t stream);
3652

3753
#include "postprocess.hpp"
3854

39-
namespace nvinfer1
40-
{
41-
int PostprocessPluginV2::enqueue(int batchSize, const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept
42-
{
43-
float* input = (float*)inputs[0];
44-
uint8_t* output = (uint8_t*)outputs[0];
45-
46-
const int H = mPostprocess.H;
47-
const int W = mPostprocess.W;
48-
const int C = mPostprocess.C;
55+
namespace nvinfer1 {
56+
int PostprocessPluginV2::enqueue(int batchSize, const void* const* inputs, void* const* outputs, void* workspace,
57+
cudaStream_t stream) noexcept {
58+
uint8_t* output = (uint8_t*)outputs[0];
4959

50-
postprocess(output, input, batchSize, H, W, C, stream);
60+
const int H = mPostprocess.H;
61+
const int W = mPostprocess.W;
62+
const int C = mPostprocess.C;
5163

52-
return 0;
64+
if (mDataType == DataType::kFLOAT) {
65+
const float* input = (const float*)inputs[0];
66+
postprocess<float>(output, input, batchSize, H, W, C, stream);
67+
} else if (mDataType == DataType::kHALF) {
68+
const half* input = (const half*)inputs[0];
69+
postprocess<half>(output, input, batchSize, H, W, C, stream);
5370
}
54-
}
71+
72+
return 0;
73+
}
74+
} // namespace nvinfer1

0 commit comments

Comments (0)