#include <cuda_fp16.h>

#include "cublas_v2.h"
#include "cuda_utils.h"

using namespace std;
// Postprocess kernel: NCHW float/half tensor in [0, 1] -> NHWC uint8 image,
// swapping RGB -> BGR and scaling by 255.
//
// Launch: 1D grid/block, one thread per output element; thread_count must
// equal batchSize * height * width * channel. The (2 - c_idx) term assumes
// channel == 3 (BGR mirror of the channel index).
// NOTE(review): the original header comment said ROUND, but the narrowing
// cast truncates — confirm whether roundf() was intended.
template <typename T>
__global__ void postprocess_kernel(uint8_t* output, const T* input, const int batchSize, const int height,
                                   const int width, const int channel, const int thread_count) {
    int index = threadIdx.x + blockIdx.x * blockDim.x;
    if (index >= thread_count)
        return;

    // Decompose the flat NHWC output index into (b, h, w, c).
    const int c_idx = index % channel;
    int idx = index / channel;
    const int w_idx = idx % width;
    idx = idx / width;
    const int h_idx = idx % height;
    const int b_idx = idx / height;

    // Gather from the NCHW input, reading channel (2 - c) for the RGB->BGR swap.
    int g_idx = b_idx * height * width * channel + (2 - c_idx) * height * width + h_idx * width + w_idx;
    float val = (float)input[g_idx];
    float tt = val * 255.f;
    // Clamp to the representable uint8 range before the narrowing cast.
    if (tt > 255)
        tt = 255;
    if (tt < 0)
        tt = 0;
    output[index] = (uint8_t)tt;
}

// Explicit instantiations so the kernel definitions are emitted in this
// translation unit for the two supported input precisions.
template __global__ void postprocess_kernel<float>(uint8_t* output, const float* input, const int batchSize,
                                                   const int height, const int width, const int channel,
                                                   const int thread_count);
template __global__ void postprocess_kernel<half>(uint8_t* output, const half* input, const int batchSize,
                                                  const int height, const int width, const int channel,
                                                  const int thread_count);
// Host-side launcher for postprocess_kernel: one thread per output element,
// enqueued asynchronously on the given stream (no synchronization here).
template <typename T>
void postprocess(uint8_t* output, const T* input, int batchSize, int height, int width, int channel,
                 cudaStream_t stream) {
    int thread_count = batchSize * height * width * channel;
    // Guard against empty (or overflowed-negative) work: a grid of
    // (0 - 1) / block + 1 would otherwise still launch one block.
    if (thread_count <= 0)
        return;

    int block = 512;
    int grid = (thread_count - 1) / block + 1;  // ceil-div

    postprocess_kernel<T><<<grid, block, 0, stream>>>(output, input, batchSize, height, width, channel, thread_count);
}

// Explicit instantiations for the two supported input precisions.
template void postprocess<float>(uint8_t* output, const float* input, int batchSize, int height, int width, int channel,
                                 cudaStream_t stream);
template void postprocess<half>(uint8_t* output, const half* input, int batchSize, int height, int width, int channel,
                                cudaStream_t stream);

#include "postprocess.hpp"

namespace nvinfer1 {

// TensorRT plugin entry point: launches the postprocess kernel that converts
// the network's float/half NCHW output into a uint8 NHWC BGR image.
// Dimensions come from the plugin's configured mPostprocess parameters.
// Returns 0 (success); the launch itself is asynchronous on `stream`.
int PostprocessPluginV2::enqueue(int batchSize, const void* const* inputs, void* const* outputs, void* workspace,
                                 cudaStream_t stream) noexcept {
    uint8_t* output = (uint8_t*)outputs[0];

    const int H = mPostprocess.H;
    const int W = mPostprocess.W;
    const int C = mPostprocess.C;

    // Dispatch on the plugin's configured precision. Any other data type
    // launches nothing and leaves the output buffer untouched.
    if (mDataType == DataType::kFLOAT) {
        const float* input = (const float*)inputs[0];
        postprocess<float>(output, input, batchSize, H, W, C, stream);
    } else if (mDataType == DataType::kHALF) {
        const half* input = (const half*)inputs[0];
        postprocess<half>(output, input, batchSize, H, W, C, stream);
    }

    return 0;
}
}  // namespace nvinfer1