-
Notifications
You must be signed in to change notification settings - Fork 404
[Plugin TRT EP] Add MemcpyToHost and MemcpyFromHost kernel implementations #557
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
162 changes: 162 additions & 0 deletions
162
plugin_execution_providers/tensorrt/src/kernels/memcpy.cc
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,162 @@ | ||
| // Copyright (c) Microsoft Corporation. All rights reserved. | ||
| // Licensed under the MIT License. | ||
|
|
||
| #include "utils.h" | ||
| #include "memcpy.h" | ||
| #include <cuda_runtime.h> | ||
|
|
||
| namespace trt_ep { | ||
|
|
||
| template <typename T> | ||
| OrtStatus* MemcpyKernelBase::CreateImpl(const OrtKernelInfo* info, void* state, | ||
| /*out*/ OrtKernelImpl*& kernel) noexcept { | ||
| try { | ||
| auto p = std::make_unique<T>(info, state, typename T::PrivateTag{}); | ||
| kernel = p.release(); | ||
| return nullptr; | ||
| } catch (const Ort::Exception& ex) { | ||
| Ort::Status status(ex); | ||
| return status.release(); | ||
| } catch (const std::exception& ex) { | ||
| Ort::Status status(ex.what(), ORT_EP_FAIL); | ||
| return status.release(); | ||
| } catch (...) { | ||
| Ort::Status status("Unknown exception in MemcpyKernelBase::Create", ORT_EP_FAIL); | ||
| return status.release(); | ||
| } | ||
| } | ||
|
|
||
| template <typename T> | ||
| static void MemcpyKernelBase::ReleaseImpl(OrtKernelImpl* this_ptr) noexcept { | ||
| delete static_cast<T*>(this_ptr); | ||
| } | ||
|
|
||
| OrtStatus* MemcpyFromHost::ComputeImpl(OrtKernelImpl* this_ptr, OrtKernelContext* kernel_ctx) noexcept { | ||
| try { | ||
| const OrtApi& ort_api = Ort::GetApi(); | ||
| const OrtValue* input_tensor = nullptr; | ||
| RETURN_IF_ERROR(ort_api.KernelContext_GetInput(kernel_ctx, 0, &input_tensor)); | ||
|
|
||
| // Get tensor shape and type | ||
| OrtTensorTypeAndShapeInfo* tensor_info = nullptr; | ||
| RETURN_IF_ERROR(ort_api.GetTensorTypeAndShape(input_tensor, &tensor_info)); | ||
|
|
||
| size_t element_count = 0; | ||
| RETURN_IF_ERROR(ort_api.GetTensorShapeElementCount(tensor_info, &element_count)); | ||
|
|
||
| ONNXTensorElementDataType element_type; | ||
| RETURN_IF_ERROR(ort_api.GetTensorElementType(tensor_info, &element_type)); | ||
|
|
||
| size_t num_dims = 0; | ||
| RETURN_IF_ERROR(ort_api.GetDimensionsCount(tensor_info, &num_dims)); | ||
|
|
||
| std::vector<int64_t> dims(num_dims); | ||
| RETURN_IF_ERROR(ort_api.GetDimensions(tensor_info, dims.data(), num_dims)); | ||
| ort_api.ReleaseTensorTypeAndShapeInfo(tensor_info); | ||
|
|
||
| // Get output tensor | ||
| OrtValue* output_tensor = nullptr; | ||
| RETURN_IF_ERROR(ort_api.KernelContext_GetOutput(kernel_ctx, 0, dims.data(), num_dims, &output_tensor)); | ||
|
|
||
| // Get data pointers | ||
| const void* input_data = nullptr; | ||
| void* output_data = nullptr; | ||
| RETURN_IF_ERROR(ort_api.GetTensorData(input_tensor, &input_data)); | ||
| RETURN_IF_ERROR(ort_api.GetTensorMutableData(output_tensor, &output_data)); | ||
|
|
||
| // Calculate size in bytes | ||
| size_t bytes = 0; | ||
| RETURN_IF_ERROR(ort_api.GetTensorSizeInBytes(input_tensor, &bytes)); | ||
|
|
||
| // Get CUDA stream from kernel context | ||
| void* cuda_stream = nullptr; | ||
| RETURN_IF_ERROR(ort_api.KernelContext_GetGPUComputeStream(kernel_ctx, &cuda_stream)); | ||
| cudaStream_t stream = static_cast<cudaStream_t>(cuda_stream); | ||
|
|
||
| // Copy from host (CPU) to device (GPU) asynchronously | ||
| cudaError_t cuda_err = cudaMemcpyAsync(output_data, input_data, bytes, cudaMemcpyHostToDevice, stream); | ||
| if (cuda_err != cudaSuccess) { | ||
| return ort_api.CreateStatus(ORT_EP_FAIL, cudaGetErrorString(cuda_err)); | ||
| } | ||
|
|
||
| return nullptr; | ||
| } catch (const Ort::Exception& ex) { | ||
| Ort::Status status(ex); | ||
| return status.release(); | ||
| } catch (const std::exception& ex) { | ||
| Ort::Status status(ex.what(), ORT_EP_FAIL); | ||
| return status.release(); | ||
| } | ||
| } | ||
|
|
||
| OrtStatus* MemcpyToHost::ComputeImpl(OrtKernelImpl* this_ptr, OrtKernelContext* kernel_ctx) noexcept { | ||
| try { | ||
| const OrtApi& ort_api = Ort::GetApi(); | ||
| const OrtValue* input_tensor = nullptr; | ||
| RETURN_IF_ERROR(ort_api.KernelContext_GetInput(kernel_ctx, 0, &input_tensor)); | ||
|
|
||
| // Get tensor shape and type | ||
| OrtTensorTypeAndShapeInfo* tensor_info = nullptr; | ||
| RETURN_IF_ERROR(ort_api.GetTensorTypeAndShape(input_tensor, &tensor_info)); | ||
|
|
||
| size_t num_dims = 0; | ||
| RETURN_IF_ERROR(ort_api.GetDimensionsCount(tensor_info, &num_dims)); | ||
|
|
||
| std::vector<int64_t> dims(num_dims); | ||
| RETURN_IF_ERROR(ort_api.GetDimensions(tensor_info, dims.data(), num_dims)); | ||
| ort_api.ReleaseTensorTypeAndShapeInfo(tensor_info); | ||
|
|
||
| // Get output tensor | ||
| OrtValue* output_tensor = nullptr; | ||
| RETURN_IF_ERROR(ort_api.KernelContext_GetOutput(kernel_ctx, 0, dims.data(), num_dims, &output_tensor)); | ||
|
|
||
| // Get data pointers | ||
| const void* input_data = nullptr; | ||
| void* output_data = nullptr; | ||
| RETURN_IF_ERROR(ort_api.GetTensorData(input_tensor, &input_data)); | ||
| RETURN_IF_ERROR(ort_api.GetTensorMutableData(output_tensor, &output_data)); | ||
|
|
||
| // Calculate size in bytes | ||
| size_t bytes = 0; | ||
| RETURN_IF_ERROR(ort_api.GetTensorSizeInBytes(input_tensor, &bytes)); | ||
|
|
||
| // Get CUDA stream from kernel context | ||
| void* cuda_stream = nullptr; | ||
| RETURN_IF_ERROR(ort_api.KernelContext_GetGPUComputeStream(kernel_ctx, &cuda_stream)); | ||
| cudaStream_t stream = static_cast<cudaStream_t>(cuda_stream); | ||
|
|
||
| // Copy from device (GPU) to host (CPU) asynchronously | ||
| cudaError_t cuda_err = cudaMemcpyAsync(output_data, input_data, bytes, cudaMemcpyDeviceToHost, stream); | ||
| if (cuda_err != cudaSuccess) { | ||
| return ort_api.CreateStatus(ORT_EP_FAIL, cudaGetErrorString(cuda_err)); | ||
| } | ||
|
|
||
| return nullptr; | ||
| } catch (const Ort::Exception& ex) { | ||
| Ort::Status status(ex); | ||
| return status.release(); | ||
| } catch (const std::exception& ex) { | ||
| Ort::Status status(ex.what(), ORT_EP_FAIL); | ||
| return status.release(); | ||
| } | ||
| } | ||
|
|
||
| ONNX_OPERATOR_KERNEL_EX( | ||
| MemcpyFromHost, | ||
| kOnnxDomain, | ||
| /*version*/ 1, // Equivalent to start_version: 14, end_version: 14 (inclusive) | ||
| (Ort::KernelDefBuilder() | ||
| .SetInputMemType(0, OrtMemType::OrtMemTypeCPUInput) | ||
| .AddTypeConstraint("T", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT))), | ||
| MemcpyFromHost) | ||
|
|
||
| ONNX_OPERATOR_KERNEL_EX( | ||
| MemcpyToHost, | ||
| kOnnxDomain, | ||
| /*version*/ 1, // Equivalent to start_version: 14, end_version: 14 (inclusive) | ||
| (Ort::KernelDefBuilder() | ||
| .SetOutputMemType(0, OrtMemType::OrtMemTypeCPUOutput) | ||
| .AddTypeConstraint("T", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT))), | ||
| MemcpyToHost) | ||
|
|
||
| } // namespace trt_ep |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,80 @@ | ||
| // Copyright (c) Microsoft Corporation. All rights reserved. | ||
| // Licensed under the MIT License. | ||
|
|
||
| struct OrtKernelImpl; | ||
| struct OrtKernelInfo; | ||
| struct OrtKernelContext; | ||
| struct OrtStatus; | ||
|
|
||
| namespace trt_ep { | ||
|
|
||
| struct MemcpyKernelBase : public OrtKernelImpl { | ||
| // Base class for MemcpyFromHost and MemcpyToHost to share common code. | ||
| protected: | ||
| MemcpyKernelBase(const OrtKernelInfo* info, void* state) : OrtKernelImpl {}, info_(info), state_(state) {} | ||
|
|
||
| template <typename T> | ||
| static OrtStatus* CreateImpl(const OrtKernelInfo* info, void* state, /*out*/ OrtKernelImpl*& kernel) noexcept; | ||
|
|
||
| template <typename T> | ||
| static void ReleaseImpl(OrtKernelImpl* this_ptr) noexcept; | ||
|
|
||
| const OrtKernelInfo* info_; | ||
| void* state_; // Custom state passed from OrtEp | ||
| }; | ||
|
|
||
| struct MemcpyFromHost : public MemcpyKernelBase { | ||
| private: | ||
| struct PrivateTag {}; // Used to prevent use of public constructor (use static MemcpyFromHost::Create()) | ||
| // Need to make the constructor public for std::make_unique(). | ||
|
|
||
| // Allow base template helper to access PrivateTag | ||
| friend struct MemcpyKernelBase; | ||
|
|
||
| public: | ||
| MemcpyFromHost(const OrtKernelInfo* info, void* state, PrivateTag) : MemcpyKernelBase(info, state) { | ||
| ort_version_supported = ORT_API_VERSION; | ||
| Compute = ComputeImpl; | ||
| Release = ReleaseImpl; | ||
| }; | ||
|
|
||
| static OrtStatus* Create(const OrtKernelInfo* info, void* state, | ||
| /*out*/ OrtKernelImpl*& kernel) noexcept { | ||
| return CreateImpl<MemcpyFromHost>(info, state, kernel); | ||
| } | ||
|
|
||
| static OrtStatus* ORT_API_CALL ComputeImpl(OrtKernelImpl* this_ptr, OrtKernelContext* kernel_ctx) noexcept; | ||
|
|
||
| static void ORT_API_CALL ReleaseImpl(OrtKernelImpl* this_ptr) noexcept { | ||
| MemcpyKernelBase::ReleaseImpl<MemcpyFromHost>(this_ptr); | ||
| }; | ||
| }; | ||
|
|
||
| struct MemcpyToHost : public MemcpyKernelBase { | ||
| private: | ||
| struct PrivateTag {}; // Used to prevent use of public constructor (use static MemcpyFromHost::Create()) | ||
| // Need to make the constructor public for std::make_unique(). | ||
|
|
||
| // Allow base template helper to access PrivateTag | ||
| friend struct MemcpyKernelBase; | ||
|
|
||
| public: | ||
| MemcpyToHost(const OrtKernelInfo* info, void* state, PrivateTag) : MemcpyKernelBase(info, state) { | ||
| ort_version_supported = ORT_API_VERSION; | ||
| Compute = ComputeImpl; | ||
| Release = ReleaseImpl; | ||
| }; | ||
|
|
||
| static OrtStatus* Create(const OrtKernelInfo* info, void* state, | ||
| /*out*/ OrtKernelImpl*& kernel) noexcept { | ||
| return CreateImpl<MemcpyToHost>(info, state, kernel); | ||
| } | ||
|
|
||
| static OrtStatus* ORT_API_CALL ComputeImpl(OrtKernelImpl* this_ptr, OrtKernelContext* kernel_ctx) noexcept; | ||
|
|
||
| static void ORT_API_CALL ReleaseImpl(OrtKernelImpl* this_ptr) noexcept { | ||
| MemcpyKernelBase::ReleaseImpl<MemcpyToHost>(this_ptr); | ||
| }; | ||
| }; | ||
|
|
||
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,97 @@ | ||
| #pragma once | ||
|
|
||
| #include "ep_utils.h" | ||
|
|
||
| namespace trt_ep { | ||
|
|
||
| /// <summary> | ||
| /// Gets an OrtDataType for a tensor type. Throws on error. | ||
| /// </summary> | ||
| /// <param name="elem_type"></param> | ||
| /// <returns></returns> | ||
| inline const OrtDataType* GetTensorType(ONNXTensorElementDataType elem_type) { | ||
| const OrtEpApi& ep_api = Ort::GetEpApi(); | ||
| const OrtDataType* result = nullptr; | ||
|
|
||
| Ort::ThrowOnError(ep_api.GetTensorDataType(elem_type, &result)); | ||
| return result; | ||
| } | ||
|
|
||
| /// <summary> | ||
| /// Contains information to create a kernel: kernel definition, creation function + state. | ||
| /// </summary> | ||
| struct KernelCreateInfo { | ||
| KernelCreateInfo() = default; | ||
| KernelCreateInfo(Ort::KernelDef def, OrtKernelCreateFunc func, void* state) | ||
| : kernel_def{std::move(def)}, kernel_create_func{func}, kernel_create_func_state{state} {} | ||
|
|
||
| Ort::KernelDef kernel_def{nullptr}; | ||
| OrtKernelCreateFunc kernel_create_func = nullptr; | ||
| void* kernel_create_func_state = nullptr; | ||
| }; | ||
|
|
||
| using BuildKernelCreateInfoFn = OrtStatus* (*)(const char*, void*, KernelCreateInfo*); | ||
|
|
||
| template <typename T> | ||
| OrtStatus* BuildKernelCreateInfo(const char* ep_name, void* create_func_state, /*out*/ KernelCreateInfo* result); | ||
|
|
||
| template <> | ||
| inline OrtStatus* BuildKernelCreateInfo<void>(const char* /*ep_name*/, void* /*create_func_state*/, | ||
| /*out*/ KernelCreateInfo* result) { | ||
| result->kernel_def = Ort::KernelDef{nullptr}; | ||
| result->kernel_create_func = nullptr; | ||
| result->kernel_create_func_state = nullptr; | ||
| return nullptr; | ||
| } | ||
|
|
||
| static constexpr const char* kOnnxDomain = ""; | ||
|
|
||
| // Naming convention for operator kernel classes with a start and end version range. | ||
| #define ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(domain, startver, endver, name) \ | ||
| example_ep_##name##_##domain##_ver##startver##_##endver | ||
|
|
||
| // Naming convention for operator kernel classes for a single version | ||
| #define ONNX_OPERATOR_KERNEL_CLASS_NAME(domain, version, name) \ | ||
| ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(domain, version, version, name) | ||
|
|
||
| // Defines a function of type BuildKernelCreateInfoFn for a kernel implementation with a start and end version range. | ||
| #define ONNX_OPERATOR_VERSIONED_KERNEL_EX(name, domain, startver, endver, builder, kernel_class) \ | ||
| class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(domain, startver, endver, name); \ | ||
| template <> \ | ||
| OrtStatus* \ | ||
| BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(domain, startver, endver, name)>( \ | ||
| const char* ep_name, \ | ||
| void* create_kernel_state, \ | ||
| KernelCreateInfo* result) { \ | ||
| try { \ | ||
| Ort::KernelDef kernel_def = builder.SetOperatorType(#name) \ | ||
| .SetDomain(domain) \ | ||
| .SetSinceVersion(startver, endver) \ | ||
| .SetExecutionProvider(ep_name) \ | ||
| .Build(); \ | ||
| \ | ||
| auto kernel_create_func = [](void* state, const OrtKernelInfo* info, \ | ||
| OrtKernelImpl** kernel_out) noexcept -> OrtStatus* { \ | ||
| RETURN_IF(kernel_out == nullptr, \ | ||
| "OrtKernelCreateFunc received a NULL kernel_out argument"); \ | ||
| \ | ||
| *kernel_out = nullptr; \ | ||
| RETURN_IF_ERROR(kernel_class::Create(info, state, *kernel_out)); \ | ||
| return nullptr; \ | ||
| }; \ | ||
| \ | ||
| *result = KernelCreateInfo(std::move(kernel_def), kernel_create_func, create_kernel_state); \ | ||
| } catch (const Ort::Exception& ex) { \ | ||
| Ort::Status status(ex); \ | ||
| return status.release(); \ | ||
| } catch (const std::exception& ex) { \ | ||
| Ort::Status status(ex.what(), ORT_EP_FAIL); \ | ||
| return status.release(); \ | ||
| } \ | ||
| return nullptr; \ | ||
| } | ||
|
|
||
| // Defines a function of type BuildKernelCreateInfoFn for a kernel implementation with a start version. | ||
| #define ONNX_OPERATOR_KERNEL_EX(name, domain, version, builder, kernel_class) \ | ||
| ONNX_OPERATOR_VERSIONED_KERNEL_EX(name, domain, version, version, builder, kernel_class) | ||
| } | ||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: formatting