1- #include " cublas_v2.h"
21#include " cuda_utils.h"
32
43using namespace std ;
54
65// postprocess (NCHW->NHWC, RGB->BGR, *255, ROUND, uint8)
7- template <typename T>
8- __global__ void postprocess_kernel (uint8_t * output, const T* input, const int batchSize, const int height,
9- const int width, const int channel, const int thread_count) {
6+ __global__ void postprocess_kernel (uint8_t * output, float * input,
7+ const int batchSize, const int height, const int width, const int channel,
8+ const int thread_count)
9+ {
1010 int index = threadIdx .x + blockIdx .x * blockDim .x ;
11- if (index >= thread_count)
12- return ;
11+ if (index >= thread_count) return ;
1312
1413 const int c_idx = index % channel;
1514 int idx = index / channel;
@@ -18,57 +17,38 @@ __global__ void postprocess_kernel(uint8_t* output, const T* input, const int ba
1817 const int h_idx = idx % height;
1918 const int b_idx = idx / height;
2019
21- int g_idx = b_idx * height * width * channel + (2 - c_idx) * height * width + h_idx * width + w_idx;
22- float val = (float )input[g_idx];
23- float tt = val * 255 .f ;
20+ int g_idx = b_idx * height * width * channel + (2 - c_idx)* height * width + h_idx * width + w_idx;
21+ float tt = input[g_idx] * 255 .f ;
2422 if (tt > 255 )
2523 tt = 255 ;
26- if (tt < 0 )
27- tt = 0 ;
28- output[index] = (uint8_t )tt;
24+ output[index] = tt;
2925}
3026
31- template __global__ void postprocess_kernel<float >(uint8_t * output, const float * input, const int batchSize,
32- const int height, const int width, const int channel,
33- const int thread_count);
34- template __global__ void postprocess_kernel<half>(uint8_t * output, const half* input, const int batchSize,
35- const int height, const int width, const int channel,
36- const int thread_count);
37-
38- template <typename T>
39- void postprocess (uint8_t * output, const T* input, int batchSize, int height, int width, int channel,
40- cudaStream_t stream) {
27+ void postprocess (uint8_t * output, float *input, int batchSize, int height, int width, int channel, cudaStream_t stream)
28+ {
4129 int thread_count = batchSize * height * width * channel;
4230 int block = 512 ;
4331 int grid = (thread_count - 1 ) / block + 1 ;
4432
45- postprocess_kernel<T> <<< grid, block, 0 , stream>>> (output, input, batchSize, height, width, channel, thread_count);
33+ postprocess_kernel << < grid, block, 0 , stream >> > (output, input, batchSize, height, width, channel, thread_count);
4634}
4735
48- template void postprocess<float >(uint8_t * output, const float * input, int batchSize, int height, int width, int channel,
49- cudaStream_t stream);
50- template void postprocess<half>(uint8_t * output, const half* input, int batchSize, int height, int width, int channel,
51- cudaStream_t stream);
5236
5337#include " postprocess.hpp"
5438
55- namespace nvinfer1 {
56- int PostprocessPluginV2::enqueue (int batchSize, const void * const * inputs, void * const * outputs, void * workspace,
57- cudaStream_t stream) noexcept {
58- uint8_t * output = (uint8_t *)outputs[0 ];
39+ namespace nvinfer1
40+ {
41+ int PostprocessPluginV2::enqueue (int batchSize, const void * const * inputs, void * const * outputs, void * workspace, cudaStream_t stream) noexcept
42+ {
43+ float * input = (float *)inputs[0 ];
44+ uint8_t * output = (uint8_t *)outputs[0 ];
5945
60- const int H = mPostprocess .H ;
61- const int W = mPostprocess .W ;
62- const int C = mPostprocess .C ;
46+ const int H = mPostprocess .H ;
47+ const int W = mPostprocess .W ;
48+ const int C = mPostprocess .C ;
6349
64- if (mDataType == DataType::kFLOAT ) {
65- const float * input = (const float *)inputs[0 ];
66- postprocess<float >(output, input, batchSize, H, W, C, stream);
67- } else if (mDataType == DataType::kHALF ) {
68- const half* input = (const half*)inputs[0 ];
69- postprocess<half>(output, input, batchSize, H, W, C, stream);
70- }
50+ postprocess (output, input, batchSize, H, W, C, stream);
7151
72- return 0 ;
73- }
74- } // namespace nvinfer1
52+ return 0 ;
53+ }
54+ }
0 commit comments