Skip to content

Commit 20b8fb1

Browse files
authored
Revert "Add FP16 support for real-esrgan in TensorRT10 (wang-xinyu#1697)"
This reverts commit 6eb0226.
1 parent 6eb0226 commit 20b8fb1

File tree

8 files changed

+448
-409
lines changed

8 files changed

+448
-409
lines changed

real-esrgan/CMakeLists.txt

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ include_directories(${PROJECT_SOURCE_DIR}/include)
2020
include_directories(/usr/local/cuda/include)
2121
link_directories(/usr/local/cuda/lib64)
2222
# tensorrt
23-
include_directories(/usr/local/TensorRT-10.10.0.31/include)
24-
link_directories(/usr/local/TensorRT-10.10.0.31/lib)
23+
include_directories(/usr/include/x86_64-linux-gnu/)
24+
link_directories(/usr/lib/x86_64-linux-gnu/)
2525

2626
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -Wall -Ofast -g -Wfatal-errors -D_MWAITXINTRIN_H_INCLUDED")
2727
cuda_add_library(myplugins SHARED preprocess.cu postprocess.cu)
@@ -40,3 +40,5 @@ target_link_libraries(real-esrgan ${OpenCV_LIBS})
4040
if(UNIX)
4141
add_definitions(-O2 -pthread)
4242
endif(UNIX)
43+
44+

real-esrgan/common.hpp

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,6 @@
1010

1111
using namespace nvinfer1;
1212

13-
static const int PRECISION_MODE = 16; // fp32 : 32, fp16 : 16
14-
1513
// TensorRT weight files have a simple space delimited format:
1614
// [type] [size] <data x size in hex>
1715
std::map<std::string, Weights> loadWeights(const std::string file) {
@@ -34,22 +32,14 @@ std::map<std::string, Weights> loadWeights(const std::string file) {
3432
// Read name and type of blob
3533
std::string name;
3634
input >> name >> std::dec >> size;
35+
wt.type = DataType::kFLOAT;
3736

38-
if (PRECISION_MODE == 16) {
39-
wt.type = DataType::kHALF;
40-
uint16_t* val = reinterpret_cast<uint16_t*>(malloc(sizeof(val) * size));
41-
for (uint32_t x = 0, y = size; x < y; ++x) {
42-
input >> std::hex >> val[x];
43-
}
44-
wt.values = val;
45-
} else {
46-
wt.type = DataType::kFLOAT;
47-
uint32_t* val = reinterpret_cast<uint32_t*>(malloc(sizeof(val) * size));
48-
for (uint32_t x = 0, y = size; x < y; ++x) {
49-
input >> std::hex >> val[x];
50-
}
51-
wt.values = val;
37+
// Load blob
38+
uint32_t* val = reinterpret_cast<uint32_t*>(malloc(sizeof(val) * size));
39+
for (uint32_t x = 0, y = size; x < y; ++x) {
40+
input >> std::hex >> val[x];
5241
}
42+
wt.values = val;
5343

5444
wt.count = size;
5545
weightMap[name] = wt;

real-esrgan/gen_wts.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,15 @@
11
import argparse
22
import os
33
import struct
4-
import numpy as np
54
from basicsr.archs.rrdbnet_arch import RRDBNet
65
from realesrgan import RealESRGANer
76
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
87

9-
10-
def float32_to_float16_hex(value):
11-
f16 = np.float16(value)
12-
u16 = np.frombuffer(f16.tobytes(), dtype=np.uint16)[0]
13-
return format(u16, "04x")
14-
15-
168
def main():
179
"""Inference demo for Real-ESRGAN.
1810
"""
1911
parser = argparse.ArgumentParser()
20-
# parser.add_argument('-i', '--input', type=str, default='../TestData3', help='Input image or folder')
12+
#parser.add_argument('-i', '--input', type=str, default='../TestData3', help='Input image or folder')
2113
parser.add_argument('-i', '--input', type=str, default='inputs', help='Input image or folder')
2214
parser.add_argument(
2315
'-n',
@@ -92,13 +84,9 @@ def main():
9284
f.write("{} {}".format(k, len(vr)))
9385
for vv in vr:
9486
f.write(" ")
95-
if args.fp32:
96-
f.write(struct.pack(">f", float(vv)).hex())
97-
else:
98-
f.write(float32_to_float16_hex(float(vv)))
87+
f.write(struct.pack(">f", float(vv)).hex())
9988
f.write("\n")
10089
print('Completed real-esrgan.wts file!')
10190

102-
10391
if __name__ == '__main__':
10492
main()

real-esrgan/postprocess.cu

Lines changed: 24 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
1-
#include "cublas_v2.h"
21
#include "cuda_utils.h"
32

43
using namespace std;
54

65
// postprocess (NCHW->NHWC, RGB->BGR, *255, ROUND, uint8)
7-
template <typename T>
8-
__global__ void postprocess_kernel(uint8_t* output, const T* input, const int batchSize, const int height,
9-
const int width, const int channel, const int thread_count) {
6+
__global__ void postprocess_kernel(uint8_t* output, float* input,
7+
const int batchSize, const int height, const int width, const int channel,
8+
const int thread_count)
9+
{
1010
int index = threadIdx.x + blockIdx.x * blockDim.x;
11-
if (index >= thread_count)
12-
return;
11+
if (index >= thread_count) return;
1312

1413
const int c_idx = index % channel;
1514
int idx = index / channel;
@@ -18,57 +17,38 @@ __global__ void postprocess_kernel(uint8_t* output, const T* input, const int ba
1817
const int h_idx = idx % height;
1918
const int b_idx = idx / height;
2019

21-
int g_idx = b_idx * height * width * channel + (2 - c_idx) * height * width + h_idx * width + w_idx;
22-
float val = (float)input[g_idx];
23-
float tt = val * 255.f;
20+
int g_idx = b_idx * height * width * channel + (2 - c_idx)* height * width + h_idx * width + w_idx;
21+
float tt = input[g_idx] * 255.f;
2422
if (tt > 255)
2523
tt = 255;
26-
if (tt < 0)
27-
tt = 0;
28-
output[index] = (uint8_t)tt;
24+
output[index] = tt;
2925
}
3026

31-
template __global__ void postprocess_kernel<float>(uint8_t* output, const float* input, const int batchSize,
32-
const int height, const int width, const int channel,
33-
const int thread_count);
34-
template __global__ void postprocess_kernel<half>(uint8_t* output, const half* input, const int batchSize,
35-
const int height, const int width, const int channel,
36-
const int thread_count);
37-
38-
template <typename T>
39-
void postprocess(uint8_t* output, const T* input, int batchSize, int height, int width, int channel,
40-
cudaStream_t stream) {
27+
void postprocess(uint8_t* output, float*input, int batchSize, int height, int width, int channel, cudaStream_t stream)
28+
{
4129
int thread_count = batchSize * height * width * channel;
4230
int block = 512;
4331
int grid = (thread_count - 1) / block + 1;
4432

45-
postprocess_kernel<T><<<grid, block, 0, stream>>>(output, input, batchSize, height, width, channel, thread_count);
33+
postprocess_kernel << <grid, block, 0, stream >> > (output, input, batchSize, height, width, channel, thread_count);
4634
}
4735

48-
template void postprocess<float>(uint8_t* output, const float* input, int batchSize, int height, int width, int channel,
49-
cudaStream_t stream);
50-
template void postprocess<half>(uint8_t* output, const half* input, int batchSize, int height, int width, int channel,
51-
cudaStream_t stream);
5236

5337
#include "postprocess.hpp"
5438

55-
namespace nvinfer1 {
56-
int PostprocessPluginV2::enqueue(int batchSize, const void* const* inputs, void* const* outputs, void* workspace,
57-
cudaStream_t stream) noexcept {
58-
uint8_t* output = (uint8_t*)outputs[0];
39+
namespace nvinfer1
40+
{
41+
int PostprocessPluginV2::enqueue(int batchSize, const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept
42+
{
43+
float* input = (float*)inputs[0];
44+
uint8_t* output = (uint8_t*)outputs[0];
5945

60-
const int H = mPostprocess.H;
61-
const int W = mPostprocess.W;
62-
const int C = mPostprocess.C;
46+
const int H = mPostprocess.H;
47+
const int W = mPostprocess.W;
48+
const int C = mPostprocess.C;
6349

64-
if (mDataType == DataType::kFLOAT) {
65-
const float* input = (const float*)inputs[0];
66-
postprocess<float>(output, input, batchSize, H, W, C, stream);
67-
} else if (mDataType == DataType::kHALF) {
68-
const half* input = (const half*)inputs[0];
69-
postprocess<half>(output, input, batchSize, H, W, C, stream);
70-
}
50+
postprocess(output, input, batchSize, H, W, C, stream);
7151

72-
return 0;
73-
}
74-
} // namespace nvinfer1
52+
return 0;
53+
}
54+
}

0 commit comments

Comments
 (0)