Skip to content

Commit 36c0dc1

Browse files
committed
plugin TRT EP init
1 parent 7a635da commit 36c0dc1

35 files changed

+7918
-0
lines changed
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
# usage:
2+
# cd build/
3+
# cmake -S ../ -B ./ -DCMAKE_BUILD_TYPE=Debug -DCMAKE_CUDA_ARCHITECTURES=80 -DCMAKE_CUDA_COMPILER=/usr/local/cuda/bin/nvcc -DORT_HOME=/home/lochi/repos/ort -DTENSORRT_HOME=/home/lochi/tensorrt/TensorRT-10.3.0.26 (see the result of "nvidia-smi --query-gpu=compute_cap --format=csv,noheader,nounits")
4+
# cmake --build ./ --config Debug
5+
cmake_minimum_required(VERSION 3.26)
6+
project(TensorRTEp VERSION 1.0)
7+
set(CMAKE_CXX_STANDARD 17)
8+
enable_language(CUDA)
9+
file(TO_CMAKE_PATH CUDAToolkit_ROOT "/usr/local/cuda")
10+
find_package(CUDAToolkit REQUIRED)
11+
12+
add_definitions(-DONNX_NAMESPACE=onnx)
13+
add_definitions(-DONNX_ML)
14+
add_definitions(-DNV_TENSORRT_MAJOR=10)
15+
add_definitions(-DNOMINMAX)
16+
file(GLOB tensorrt_src "./*.cc" "./utils/*.cc" "./cuda/unary_elementwise_ops_impl.cu")
17+
add_library(TensorRTEp SHARED ${tensorrt_src})
18+
19+
if (NOT ORT_HOME)
20+
message(FATAL_ERROR "Please specify ORT_HOME, e.g. -DORT_HOME=/path/to/ort/")
21+
endif()
22+
23+
if (NOT TENSORRT_HOME)
24+
message(FATAL_ERROR "Please specify TENSORRT_HOME, e.g. -DTENSORRT_HOME=/path/to/trt/")
25+
endif()
26+
27+
# Use release mode if not specified
28+
if (NOT CMAKE_BUILD_TYPE)
29+
set(CMAKE_BUILD_TYPE "Release")
30+
endif()
31+
32+
# Add dependencies
33+
include(FetchContent)
34+
35+
# Add GSL
36+
FetchContent_Declare(
37+
gsl
38+
GIT_REPOSITORY https://github.com/microsoft/GSL.git
39+
GIT_TAG v4.0.0 # Use a specific tag or commit
40+
)
41+
42+
FetchContent_MakeAvailable(gsl)
43+
44+
# Add flatbuffers
45+
FetchContent_Declare(
46+
flatbuffers
47+
GIT_REPOSITORY https://github.com/google/flatbuffers.git
48+
GIT_TAG v23.5.26 # Use a specific tag or commit
49+
)
50+
51+
FetchContent_MakeAvailable(flatbuffers)
52+
53+
if (WIN32)
54+
set(PLATFORM "Windows")
55+
set(ORT_LIB "${ORT_HOME}/build/${PLATFORM}/${CMAKE_BUILD_TYPE}/${CMAKE_BUILD_TYPE}/onnxruntime.lib")
56+
set(DEPS_PATH "${ORT_HOME}/build/${PLATFORM}/${CMAKE_BUILD_TYPE}/_deps")
57+
set(TRT_LIBS "${TENSORRT_HOME}/lib/nvinfer_10.lib"
58+
"${TENSORRT_HOME}/lib/nvinfer_plugin_10.lib"
59+
"${TENSORRT_HOME}/lib/nvonnxparser_10.lib")
60+
set(DEPS_LIBS "${DEPS_PATH}/flatbuffers-build/${CMAKE_BUILD_TYPE}/flatbuffers.lib"
61+
"${DEPS_PATH}/onnx-build/${CMAKE_BUILD_TYPE}/onnx.lib"
62+
"${DEPS_PATH}/onnx-build/${CMAKE_BUILD_TYPE}/onnx_proto.lib")
63+
64+
if(CMAKE_BUILD_TYPE STREQUAL "Debug")
65+
set(DEPS_LIBS ${DEPS_LIBS}
66+
"${DEPS_PATH}/protobuf-build/${CMAKE_BUILD_TYPE}/libprotobufd.lib"
67+
"${DEPS_PATH}/protobuf-build/${CMAKE_BUILD_TYPE}/libprotocd.lib")
68+
else()
69+
set(DEPS_LIBS ${DEPS_LIBS}
70+
"${DEPS_PATH}/protobuf-build/${CMAKE_BUILD_TYPE}/libprotobuf.lib"
71+
"${DEPS_PATH}/protobuf-build/${CMAKE_BUILD_TYPE}/libprotoc.lib")
72+
endif()
73+
else()
74+
set(PLATFORM "Linux")
75+
set(ORT_LIB "${ORT_HOME}/build/${PLATFORM}/${CMAKE_BUILD_TYPE}/libonnxruntime.so")
76+
set(DEPS_PATH "${ORT_HOME}/build/${PLATFORM}/${CMAKE_BUILD_TYPE}/_deps")
77+
set(TRT_LIBS "${TENSORRT_HOME}/lib/libnvinfer.so"
78+
"${TENSORRT_HOME}/lib/libnvinfer_plugin.so"
79+
"${TENSORRT_HOME}/lib/libnvonnxparser.so")
80+
set(DEPS_LIBS "${DEPS_PATH}/flatbuffers-build/libflatbuffers.a"
81+
"${DEPS_PATH}/onnx-build/libonnx.a"
82+
"${DEPS_PATH}/onnx-build/libonnx_proto.a")
83+
84+
if(CMAKE_BUILD_TYPE STREQUAL "Debug")
85+
set(DEPS_LIBS ${DEPS_LIBS}
86+
"${DEPS_PATH}/protobuf-build/libprotobufd.a"
87+
"${DEPS_PATH}/protobuf-build/libprotocd.a")
88+
else()
89+
set(DEPS_LIBS ${DEPS_LIBS}
90+
"${DEPS_PATH}/protobuf-build/libprotobuf.a"
91+
"${DEPS_PATH}/protobuf-build/libprotoc.a")
92+
endif()
93+
endif()
94+
95+
MESSAGE(STATUS "Looking for following dependencies ...")
96+
MESSAGE(STATUS "Platform : ${PLATFORM}")
97+
MESSAGE(STATUS "ORT home : ${ORT_HOME}")
98+
MESSAGE(STATUS "ORT lib : ${ORT_LIB}")
99+
MESSAGE(STATUS "Deps path: ${DEPS_PATH}")
100+
MESSAGE(STATUS "Deps libs: ${DEPS_LIBS}")
101+
MESSAGE(STATUS "TRT libs : ${TRT_LIBS}")
102+
103+
target_include_directories(TensorRTEp PUBLIC "${ORT_HOME}/include/onnxruntime/core/session/"
104+
"./utils"
105+
"/usr/local/cuda/include"
106+
${TENSORRT_HOME}/include
107+
"${DEPS_PATH}/flatbuffers-src/include"
108+
"${DEPS_PATH}/gsl-src/include"
109+
"${DEPS_PATH}/onnx-src"
110+
"${DEPS_PATH}/onnx-build"
111+
"${DEPS_PATH}/protobuf-src/src"
112+
)
113+
114+
target_link_libraries(TensorRTEp PUBLIC ${ORT_LIB}
115+
${TRT_LIBS}
116+
CUDA::cudart
117+
${DEPS_LIBS}
118+
GSL
119+
flatbuffers)
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
#include <stdint.h>
6+
7+
namespace onnxruntime {
8+
namespace cuda {
9+
10+
// We would like to use 64-bit integer to support large matrices. However, CUDA seems to support only 32-bit integer
11+
// For now, use int32_t to ensure that both Linux and Windows see this as 32 bit integer type.
12+
#ifndef CUDA_LONG
13+
#define CUDA_LONG int32_t
14+
#endif
15+
16+
template <class INT, class INT2>
17+
inline __host__ __device__ INT CeilDiv(INT a, INT2 b) // ceil(a/b)
18+
{
19+
return (INT)(((size_t)a + (size_t)b - 1) / (size_t)b); // these size_t casts are necessary since b may be INT_MAX (for maxGridSize[])
20+
}
21+
22+
struct GridDim {
23+
enum : CUDA_LONG {
24+
maxThreadsPerBlock = 256, // max threads per block
25+
maxElementsPerThread = 4, // max element processed per thread
26+
};
27+
};
28+
29+
template <typename InT, typename OutT, typename FuncT, int NumThreadsPerBlock, int NumElementsPerThread>
30+
__global__ void _UnaryElementWise(
31+
const InT* input_data,
32+
OutT* output_data,
33+
const FuncT functor,
34+
CUDA_LONG N) {
35+
CUDA_LONG start = NumElementsPerThread * NumThreadsPerBlock * blockIdx.x + threadIdx.x;
36+
InT value[NumElementsPerThread];
37+
38+
CUDA_LONG id = start;
39+
#pragma unroll
40+
for (int i = 0; i < NumElementsPerThread; i++) {
41+
if (id < N) {
42+
value[i] = input_data[id];
43+
id += NumThreadsPerBlock;
44+
}
45+
}
46+
47+
id = start;
48+
#pragma unroll
49+
for (int i = 0; i < NumElementsPerThread; i++) {
50+
if (id < N) {
51+
output_data[id] = functor(value[i]);
52+
id += NumThreadsPerBlock;
53+
}
54+
}
55+
}
56+
57+
template <typename InT, typename OutT, typename FuncT>
58+
void UnaryElementWiseImpl(
59+
cudaStream_t stream,
60+
const InT* input_data,
61+
OutT* output_data,
62+
const FuncT& func,
63+
size_t count) {
64+
if (count == 0) // special case where there's a dim value of 0 in the shape
65+
return;
66+
67+
int blocksPerGrid = static_cast<int>(CeilDiv(count, GridDim::maxThreadsPerBlock * GridDim::maxElementsPerThread));
68+
CUDA_LONG N = static_cast<CUDA_LONG>(count);
69+
_UnaryElementWise<InT, OutT, FuncT, GridDim::maxThreadsPerBlock, GridDim::maxElementsPerThread>
70+
<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
71+
input_data,
72+
output_data,
73+
func,
74+
N);
75+
}
76+
77+
} // namespace cuda
78+
} // namespace onnxruntime
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#include <cuda_runtime.h>
5+
#include "cu_inc/unary_elementwise_impl.cuh"
6+
7+
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11080
8+
#include "cuda_fp8.h"
9+
#endif
10+
#include <cuda_fp16.h>
11+
12+
namespace onnxruntime {
13+
14+
namespace cuda {
15+
16+
// the postfix of means the types supported by the op:
17+
// B: uint8_t
18+
// W: uint16_t
19+
// U: uint32_t
20+
// Z: uint64_t
21+
// C: int8_t
22+
// S: int16_t
23+
// I: int32_t
24+
// L: int64_t
25+
// H: float16
26+
// F: float
27+
// D: double
28+
// O: bool
29+
// X: BFloat16
30+
31+
// When casting, half needs to be converted via float type from most other types
32+
template <typename T>
33+
struct ViaTypeMap {
34+
typedef T ViaT;
35+
};
36+
37+
template <>
38+
struct ViaTypeMap<half> {
39+
typedef float ViaT;
40+
};
41+
42+
template <typename InT, typename OutT>
43+
struct OP_Cast {
44+
__device__ __inline__ OutT operator()(const InT& a) const {
45+
const bool any_float16 = std::is_same<half, InT>::value || std::is_same<half, OutT>::value;
46+
typedef typename std::conditional<any_float16, half, OutT>::type T;
47+
typedef typename ViaTypeMap<T>::ViaT ViaT;
48+
return (OutT)((ViaT)a);
49+
}
50+
};
51+
52+
#define IMPL_CAST_IMPL(InT, OutT) \
53+
void Explicit_Impl_Cast(cudaStream_t stream, const InT* input_data, OutT* output_data, size_t count) { \
54+
UnaryElementWiseImpl(stream, input_data, output_data, OP_Cast<InT, OutT>(), count); \
55+
}
56+
57+
#define IMPL_CAST_IMPL_THROW(InT, OutT) \
58+
void Explicit_Impl_Cast(cudaStream_t /*stream*/, const InT* /*input_data*/, OutT* /*output_data*/, \
59+
size_t /*count*/) { \
60+
ORT_THROW("Cast from " #InT " to " #OutT " must define saturate."); \
61+
}
62+
63+
#define IMPL_CAST_IMPL_FROM(T) \
64+
IMPL_CAST_IMPL(T, half) \
65+
IMPL_CAST_IMPL(T, float) \
66+
IMPL_CAST_IMPL(T, double) \
67+
IMPL_CAST_IMPL(T, int8_t) \
68+
IMPL_CAST_IMPL(T, int16_t) \
69+
IMPL_CAST_IMPL(T, int32_t) \
70+
IMPL_CAST_IMPL(T, int64_t) \
71+
IMPL_CAST_IMPL(T, uint8_t) \
72+
IMPL_CAST_IMPL(T, uint16_t) \
73+
IMPL_CAST_IMPL(T, uint32_t) \
74+
IMPL_CAST_IMPL(T, uint64_t) \
75+
IMPL_CAST_IMPL(T, bool) \
76+
//IMPL_CAST_IMPL(T, BFloat16)
77+
78+
IMPL_CAST_IMPL_FROM(half)
79+
IMPL_CAST_IMPL_FROM(float)
80+
IMPL_CAST_IMPL_FROM(double)
81+
IMPL_CAST_IMPL_FROM(int8_t)
82+
IMPL_CAST_IMPL_FROM(int16_t)
83+
IMPL_CAST_IMPL_FROM(int32_t)
84+
IMPL_CAST_IMPL_FROM(int64_t)
85+
IMPL_CAST_IMPL_FROM(uint8_t)
86+
IMPL_CAST_IMPL_FROM(uint16_t)
87+
IMPL_CAST_IMPL_FROM(uint32_t)
88+
IMPL_CAST_IMPL_FROM(uint64_t)
89+
IMPL_CAST_IMPL_FROM(bool)
90+
//IMPL_CAST_IMPL_FROM(BFloat16)
91+
92+
} // namespace cuda
93+
} // namespace onnxruntime
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
6+
#include <stdint.h>
7+
#include <cuda_fp16.h>
8+
#include <cuda_runtime.h>
9+
10+
namespace onnxruntime {
11+
namespace cuda {
12+
13+
// Cast
14+
15+
#define DECL_IMPL_CAST(InT, OutT) \
16+
void Explicit_Impl_Cast(cudaStream_t stream, const InT* input_data, OutT* output_data, size_t count);
17+
18+
#define DECL_IMPL_CAST_FROM(T) \
19+
DECL_IMPL_CAST(T, half) \
20+
DECL_IMPL_CAST(T, float) \
21+
DECL_IMPL_CAST(T, double) \
22+
DECL_IMPL_CAST(T, int8_t) \
23+
DECL_IMPL_CAST(T, int16_t) \
24+
DECL_IMPL_CAST(T, int32_t) \
25+
DECL_IMPL_CAST(T, int64_t) \
26+
DECL_IMPL_CAST(T, uint8_t) \
27+
DECL_IMPL_CAST(T, uint16_t) \
28+
DECL_IMPL_CAST(T, uint32_t) \
29+
DECL_IMPL_CAST(T, uint64_t) \
30+
DECL_IMPL_CAST(T, bool) \
31+
//DECL_IMPL_CAST(T, BFloat16)
32+
33+
DECL_IMPL_CAST_FROM(half)
34+
DECL_IMPL_CAST_FROM(float)
35+
DECL_IMPL_CAST_FROM(double)
36+
DECL_IMPL_CAST_FROM(int8_t)
37+
DECL_IMPL_CAST_FROM(int16_t)
38+
DECL_IMPL_CAST_FROM(int32_t)
39+
DECL_IMPL_CAST_FROM(int64_t)
40+
DECL_IMPL_CAST_FROM(uint8_t)
41+
DECL_IMPL_CAST_FROM(uint16_t)
42+
DECL_IMPL_CAST_FROM(uint32_t)
43+
DECL_IMPL_CAST_FROM(uint64_t)
44+
DECL_IMPL_CAST_FROM(bool)
45+
//DECL_IMPL_CAST_FROM(BFloat16)
46+
47+
template <typename InT, typename OutT>
48+
void Impl_Cast(cudaStream_t stream, const InT* input_data, OutT* output_data, size_t count) {
49+
Explicit_Impl_Cast(stream, input_data, output_data, count);
50+
}
51+
52+
} // namespace cuda
53+
54+
} // namespace onnxruntime
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
#define ORT_API_MANUAL_INIT
2+
#include "onnxruntime_cxx_api.h"
3+
#undef ORT_API_MANUAL_INIT
4+
5+
#include <gsl/gsl>
6+
#include <cassert>
7+
#include <cstring>
8+
#include <memory>
9+
#include <string>
10+
#include <unordered_map>
11+
#include <vector>
12+

0 commit comments

Comments
 (0)