Skip to content

Commit 8bb3b07

Browse files
Implement experimental intermediate cross CPU EP allocation (microsoft#24371)
### Description <!-- Describe your changes. --> ONNX Runtime manages a number of CPU-based accelerators, i.e. those that can operate on CPU-based inputs. However, several of them, like `Qnn`, `Openvino` and `Vitis`, may require CPU-based inputs to be aligned to 4K so they can be memory mapped, or may prefer to override the device with their own CPU-accessible allocator. To mitigate that, we introduce a new CPU-based allocator that produces 4K-aligned memory. We also adjust the allocation planner to override the plain CPU device. When we detect a compiled CPU-based EP, we adjust the device accordingly by requesting the EP to return `OrtMemType::OrtMemTypeCPUInput`. This gives the EP an opportunity to return either a GPU/NPU device or a CPU device, depending on the mode it is operating in. We select the device with the larger alignment between CPU default devices. We also adjust memory patterns to make sure 4K alignment is respected in the contiguous buffers when appropriate. ### Motivation and Context CPU-based providers accept CPU-based inputs, but some have a requirement of 4K-aligned allocations; otherwise the input incurs an extra copy. This is especially noticeable with intermediate values that are produced by upstream CPU-based nodes. Qnn has its own allocator when it is enabled; we make sure it is correctly advertised to the allocation planner. This PR excludes Qnn allocator usage for intermediate values due to the overhead contributed by memhandle management. Cc: @quic-ashigarg --------- Co-authored-by: edgchen1 <[email protected]>
1 parent 3a7c8b3 commit 8bb3b07

19 files changed

+342
-148
lines changed

cmake/onnxruntime_unittests.cmake

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1291,7 +1291,7 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
12911291
endif()
12921292
if (CMAKE_SYSTEM_NAME MATCHES "AIX")
12931293
list(APPEND onnxruntime_perf_test_libs onnxruntime_graph onnxruntime_session onnxruntime_providers onnxruntime_framework onnxruntime_util onnxruntime_mlas onnxruntime_optimizer onnxruntime_flatbuffers iconv re2 gtest absl_failure_signal_handler absl_examine_stack absl_flags_parse absl_flags_usage absl_flags_usage_internal)
1294-
endif()
1294+
endif()
12951295
target_link_libraries(onnxruntime_perf_test PRIVATE ${onnxruntime_perf_test_libs} Threads::Threads)
12961296
if(WIN32)
12971297
target_link_libraries(onnxruntime_perf_test PRIVATE debug dbghelp advapi32)
@@ -1301,7 +1301,7 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
13011301
endif()
13021302
set_target_properties(onnxruntime_perf_test PROPERTIES FOLDER "ONNXRuntimeTest")
13031303

1304-
endif()
1304+
endif()
13051305

13061306

13071307
if(onnxruntime_USE_QNN)

include/onnxruntime/core/framework/allocator.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ struct OrtArenaCfg {
4141

4242
namespace onnxruntime {
4343
constexpr const char* CPU = "Cpu";
44+
constexpr const char* CPU_ALIGNED_4K = "CpuAligned4K";
4445
constexpr const char* CUDA = "Cuda";
4546
constexpr const char* CUDA_PINNED = "CudaPinned";
4647
constexpr const char* CANN = "Cann";
@@ -57,6 +58,7 @@ constexpr const char* WEBGPU_BUFFER = "WebGPU_Buffer";
5758
constexpr const char* WEBNN_TENSOR = "WebNN_Tensor";
5859

5960
constexpr size_t kAllocAlignment = 256;
61+
constexpr const size_t kAlloc4KAlignment = 4096;
6062

6163
class IAllocator;
6264
class Stream;
@@ -270,4 +272,7 @@ using AllocatorMap = std::map<OrtDevice, AllocatorPtr>;
270272

271273
void* AllocatorDefaultAlloc(size_t size);
272274
void AllocatorDefaultFree(void* p);
275+
void* AllocatorDefaultAllocAligned(size_t size, size_t alignment);
276+
void AllocatorDefaultFreeAligned(void* p, size_t alignment);
277+
273278
} // namespace onnxruntime

include/onnxruntime/core/framework/ortdevice.h

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ struct OrtDevice {
1111
using DeviceType = int8_t;
1212
using MemoryType = int8_t;
1313
using DeviceId = int16_t;
14+
using Alignment = size_t;
1415

1516
// Pre-defined device types.
1617
static const DeviceType CPU = 0;
@@ -28,31 +29,40 @@ struct OrtDevice {
2829
static const MemoryType QNN_HTP_SHARED = 4;
2930
};
3031

31-
constexpr OrtDevice(DeviceType device_type_, MemoryType memory_type_, DeviceId device_id_)
32+
constexpr OrtDevice(DeviceType device_type_, MemoryType memory_type_, DeviceId device_id_, Alignment alignment) noexcept
3233
: device_type(device_type_),
3334
memory_type(memory_type_),
34-
device_id(device_id_) {}
35+
device_id(device_id_),
36+
alignment(alignment) {}
3537

36-
constexpr OrtDevice() : OrtDevice(CPU, MemType::DEFAULT, 0) {}
38+
constexpr OrtDevice(DeviceType device_type_, MemoryType memory_type_, DeviceId device_id_) noexcept
39+
: OrtDevice(device_type_, memory_type_, device_id_, 0) {}
3740

38-
DeviceType Type() const {
41+
constexpr OrtDevice() noexcept : OrtDevice(CPU, MemType::DEFAULT, 0) {}
42+
43+
DeviceType Type() const noexcept {
3944
return device_type;
4045
}
4146

42-
MemoryType MemType() const {
47+
MemoryType MemType() const noexcept {
4348
return memory_type;
4449
}
4550

46-
DeviceId Id() const {
51+
DeviceId Id() const noexcept {
4752
return device_id;
4853
}
4954

55+
Alignment GetAlignment() const noexcept {
56+
return alignment;
57+
}
58+
5059
std::string ToString() const {
5160
std::ostringstream ostr;
5261
ostr << "Device:["
5362
<< "DeviceType:" << static_cast<int>(device_type)
5463
<< " MemoryType:" << static_cast<int>(memory_type)
5564
<< " DeviceId:" << device_id
65+
<< " Alignment:" << alignment
5666
<< "]";
5767
return ostr.str();
5868
}
@@ -62,6 +72,7 @@ struct OrtDevice {
6272
auto h = std::hash<int>()(device_type);
6373
onnxruntime::HashCombine(memory_type, h);
6474
onnxruntime::HashCombine(device_id, h);
75+
onnxruntime::HashCombine(alignment, h);
6576
return h;
6677
}
6778

@@ -71,8 +82,10 @@ struct OrtDevice {
7182
return device_type < other.device_type;
7283
if (memory_type != other.memory_type)
7384
return memory_type < other.memory_type;
85+
if (device_id != other.device_id)
86+
return device_id < other.device_id;
7487

75-
return device_id < other.device_id;
88+
return alignment < other.alignment;
7689
}
7790

7891
private:
@@ -84,6 +97,9 @@ struct OrtDevice {
8497

8598
// Device index.
8699
int32_t device_id : 16;
100+
101+
// Required alignment
102+
Alignment alignment;
87103
};
88104

89105
inline bool operator==(const OrtDevice& left, const OrtDevice& other) {

onnxruntime/core/framework/allocation_planner.cc

Lines changed: 58 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include <sstream>
99
#include <ctime>
1010
#include <iomanip>
11+
#include <iterator>
1112
#include "core/common/exceptions.h"
1213
#include "core/common/inlined_containers.h"
1314
#include "core/common/safeint.h"
@@ -725,6 +726,25 @@ class PlannerImpl {
725726
ProcessDef(index, graph_viewer_.GetNodeArg(pair.first));
726727
}
727728

729+
// If the suggested_device is also CPU and default mem type, then
730+
// we check which one has higher alignment and use that one if it is so.
731+
// If the suggested device is CPU, but not the default mem type, then
732+
// it is a CPU accessible memory device allocator. They typically have a page aligment
733+
// so that would satisfy the alignment requirement of any other CPU consumers.
734+
// If one device is not on CPU, we default on the one that is CPU.
735+
auto determine_device = [](const OrtDevice& output_device, const OrtDevice& suggested_device) -> OrtDevice {
736+
if (output_device.Type() == OrtDevice::CPU && suggested_device.Type() == OrtDevice::CPU) {
737+
if (output_device.MemType() == OrtDevice::MemType::DEFAULT &&
738+
suggested_device.MemType() == OrtDevice::MemType::DEFAULT) {
739+
return (output_device.GetAlignment() >= suggested_device.GetAlignment()) ? output_device : suggested_device;
740+
} else {
741+
return (output_device.MemType() != OrtDevice::MemType::DEFAULT) ? output_device : suggested_device;
742+
}
743+
} else {
744+
return (output_device.Type() == OrtDevice::CPU) ? output_device : suggested_device;
745+
}
746+
};
747+
728748
InlinedHashSet<OrtValueIndex> set_node_arg_has_explicit_consumer;
729749

730750
InlinedHashMap<OrtValueIndex, const IExecutionProvider*> map_implicitly_consumed_node_arg_to_ep;
@@ -756,6 +776,7 @@ class PlannerImpl {
756776
// Add location information if applicable for the provided input def
757777
auto process_input = [&graph_inputs, &exec_provider, &p_kernel_def, &is_implicit_input,
758778
&set_node_arg_has_explicit_consumer,
779+
&determine_device,
759780
&map_implicitly_consumed_node_arg_to_ep,
760781
&set_implicitly_consumed_node_arg_has_heterogenous_ep_consumers,
761782
this](const NodeArg& input, size_t arg_idx) {
@@ -856,9 +877,12 @@ class PlannerImpl {
856877
// we have seen
857878
plan_.SetLocation(static_cast<size_t>(index), exec_provider->GetOrtDeviceByMemType(OrtMemType::OrtMemTypeDefault));
858879
} else {
859-
// Default the location to CPU
860-
plan_.SetLocation(static_cast<size_t>(index),
861-
execution_providers_.Get(CPU)->GetOrtDeviceByMemType(OrtMemType::OrtMemTypeDefault));
880+
// We want to minimize the amount of copies, so we want at least one
881+
// device to match or match both if they are CPU based.
882+
OrtDevice result = determine_device(
883+
already_seen_ep_for_node_arg->second->GetOrtDeviceByMemType(OrtMemType::OrtMemTypeDefault),
884+
exec_provider->GetOrtDeviceByMemType(OrtMemType::OrtMemTypeDefault));
885+
plan_.SetLocation(static_cast<size_t>(index), result);
862886
set_implicitly_consumed_node_arg_has_heterogenous_ep_consumers.insert(index);
863887
}
864888
}
@@ -881,7 +905,37 @@ class PlannerImpl {
881905
if (!node_output->Exists()) continue;
882906
OrtValueIndex index = Index(node_output->Name());
883907
ProcessDef(index, node_output);
884-
plan_.SetLocation(static_cast<size_t>(index), exec_provider->GetOrtDeviceByMemType(p_kernel_def->OutputMemoryType(i)));
908+
OrtDevice output_device = exec_provider->GetOrtDeviceByMemType(p_kernel_def->OutputMemoryType(i));
909+
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
910+
// Downstream nodes of certain providers may require a CPU accessible location override
911+
// to make sure the EP does not incur an unnecessary copy.
912+
// We only do it for CPU based EPs. We are not likely to encounter
913+
// non CPU devices here since they are already taken care of by using MemCpy nodes earlier.
914+
// However, we still ignore them.
915+
if (output_device.Type() == OrtDevice::CPU &&
916+
output_device.MemType() == OrtDevice::MemType::DEFAULT) {
917+
const auto& output_name = node_output->Name();
918+
const auto consumers = graph_viewer_.GetConsumerNodes(output_name);
919+
for (const auto* consumer : consumers) {
920+
if (consumer != nullptr) {
921+
const auto& ep_type = consumer->GetExecutionProviderType();
922+
auto suggested_device = execution_providers_.Get(ep_type)->GetOrtDeviceByMemType(
923+
OrtMemType::OrtMemTypeCPUInput);
924+
if (suggested_device.Type() == OrtDevice::CPU &&
925+
suggested_device.MemType() == OrtDevice::MemType::DEFAULT) {
926+
output_device = determine_device(output_device, suggested_device);
927+
} else if (suggested_device.Type() == OrtDevice::CPU) {
928+
// Edge case: there are more than one downstream nodes that suggest their own CPU accessible
929+
// memory. In that case, we can not win them all, but the chosen device would still make it run
930+
// and reduce a number of copies for some.
931+
output_device = suggested_device;
932+
break;
933+
}
934+
}
935+
}
936+
}
937+
#endif
938+
plan_.SetLocation(static_cast<size_t>(index), output_device);
885939
}
886940
}
887941
}

onnxruntime/core/framework/allocator.cc

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,7 @@ bool IAllocator::CalcMemSizeForArrayWithAlignment(size_t nmemb, size_t size, siz
4141
}
4242

4343
#ifdef USE_MIMALLOC
44-
void* AllocatorDefaultAlloc(size_t size) {
45-
const size_t alignment = MlasGetPreferredBufferAlignment();
44+
void* AllocatorDefaultAllocAligned(size_t size, size_t alignment) {
4645
if (size <= 0) return nullptr;
4746
size += MLAS_SYMM_QGEMM_BUF_OVERRUN;
4847
void* p;
@@ -71,10 +70,18 @@ void AllocatorDefaultFree(void* p) {
7170
#endif
7271
}
7372

73+
void AllocatorDefaultFreeAligned(void* p, size_t alignment) {
74+
#if defined(_MSC_VER)
75+
mi_free_aligned(p, alignment);
7476
#else
75-
void* AllocatorDefaultAlloc(size_t size) {
76-
const size_t alignment = MlasGetPreferredBufferAlignment();
77-
if (size <= 0) return nullptr;
77+
mi_free(p);
78+
#endif
79+
}
80+
81+
#else
82+
83+
void* AllocatorDefaultAllocAligned(size_t size, size_t alignment) {
84+
if (size == 0) return nullptr;
7885
size += MLAS_SYMM_QGEMM_BUF_OVERRUN;
7986
void* p;
8087
#if _MSC_VER
@@ -101,14 +108,25 @@ void AllocatorDefaultFree(void* p) {
101108
#endif
102109
}
103110

111+
void AllocatorDefaultFreeAligned(void* p, size_t /* alignment */) {
112+
AllocatorDefaultFree(p);
113+
}
114+
104115
#endif // USE_MIMALLOC
105116

117+
void* AllocatorDefaultAlloc(size_t size) {
118+
const size_t alignment = MlasGetPreferredBufferAlignment();
119+
return AllocatorDefaultAllocAligned(size, alignment);
120+
}
121+
106122
void* CPUAllocator::Alloc(size_t size) {
107-
return AllocatorDefaultAlloc(size);
123+
const auto alignment = std::max(Info().device.GetAlignment(), MlasGetPreferredBufferAlignment());
124+
return AllocatorDefaultAllocAligned(size, alignment);
108125
}
109126

110127
void CPUAllocator::Free(void* p) {
111-
AllocatorDefaultFree(p);
128+
const auto alignment = std::max(Info().device.GetAlignment(), MlasGetPreferredBufferAlignment());
129+
AllocatorDefaultFreeAligned(p, alignment);
112130
}
113131

114132
void* AllocateBufferWithOptions(IAllocator& alloc, size_t size, bool use_reserve, Stream* stream, WaitNotificationFn wait_fn) {
@@ -168,6 +186,11 @@ ORT_API_STATUS_IMPL(OrtApis::CreateMemoryInfo, _In_ const char* name1, enum OrtA
168186
onnxruntime::QNN_HTP_SHARED, type,
169187
OrtDevice(OrtDevice::CPU, OrtDevice::MemType::QNN_HTP_SHARED, static_cast<OrtDevice::DeviceId>(id1)),
170188
id1, mem_type1);
189+
} else if (strcmp(name1, onnxruntime::CPU_ALIGNED_4K) == 0) {
190+
*out = new OrtMemoryInfo(
191+
onnxruntime::CPU_ALIGNED_4K, type,
192+
OrtDevice(OrtDevice::CPU, OrtDevice::MemType::DEFAULT, static_cast<OrtDevice::DeviceId>(id1), onnxruntime::kAlloc4KAlignment),
193+
id1, mem_type1);
171194
} else {
172195
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Specified device is not supported.");
173196
}

onnxruntime/core/framework/execution_frame.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -529,8 +529,11 @@ Status ExecutionFrame::AllocateMLValueTensorSelfOwnBufferHelper(OrtValue& ort_va
529529
return Status(ONNXRUNTIME, FAIL, "Trying to allocate memory for unused optional inputs/outputs");
530530
}
531531

532+
// This alignment is used to properly space out individual chunks in mempatterns memory buffer.
533+
const auto alignment = std::max(location.GetAlignment(), kAllocAlignment);
534+
532535
size_t size = 0;
533-
ORT_RETURN_IF_ERROR(Tensor::CalculateTensorStorageSize(element_type, shape, kAllocAlignment, size));
536+
ORT_RETURN_IF_ERROR(Tensor::CalculateTensorStorageSize(element_type, shape, alignment, size));
534537

535538
// Lazily get the allocator only if needed.
536539
AllocatorPtr alloc = nullptr;

onnxruntime/core/providers/cpu/cpu_execution_provider.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,10 @@ CPUExecutionProvider::CPUExecutionProvider(const CPUExecutionProviderInfo& info)
3232

3333
std::vector<AllocatorPtr> CPUExecutionProvider::CreatePreferredAllocators() {
3434
const bool create_arena = DoesCpuAllocatorSupportArenaUsage() ? info_.create_arena : false;
35-
AllocatorCreationInfo device_info{[](int) { return std::make_unique<CPUAllocator>(); },
36-
DEFAULT_CPU_ALLOCATOR_DEVICE_ID, create_arena};
35+
AllocatorCreationInfo device_info_cpu{[](int) { return std::make_unique<CPUAllocator>(); },
36+
DEFAULT_CPU_ALLOCATOR_DEVICE_ID, create_arena};
3737

38-
return std::vector<AllocatorPtr>{CreateAllocator(device_info)};
38+
return std::vector<AllocatorPtr>{CreateAllocator(device_info_cpu)};
3939
}
4040

4141
// Forward declarations of op kernels

onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1850,9 +1850,10 @@ Status QnnBackendManager::GetOrRegisterContextMemHandle(Qnn_ContextHandle_t cont
18501850
if (did_register) {
18511851
HtpSharedMemoryAllocator::AllocationCleanUpFn unregister_mem_handle =
18521852
[&logger = *logger_,
1853+
shared_memory_address,
18531854
weak_backend_manager = weak_from_this(),
18541855
weak_context_handle_record = std::weak_ptr{context_handle_record}](
1855-
void* shared_memory_address) {
1856+
void* /* allocation_base_address */) {
18561857
// Lock QnnBackendManager shared_ptr to ensure that QNN interface is still valid.
18571858
auto backend_manager = weak_backend_manager.lock();
18581859
if (!backend_manager) {

onnxruntime/core/providers/qnn/builder/qnn_model.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,9 @@ static Status BindQnnTensorMemoryToOrtValueMemory(const logging::Logger& logger,
200200
Qnn_ContextHandle_t qnn_context,
201201
Qnn_Tensor_t& qnn_tensor) {
202202
// either set qnn_tensor memHandle or clientBuf
203-
const bool uses_shared_memory = ort_value_memory_info == HtpSharedMemoryAllocator::AssociatedMemoryInfo();
203+
const static auto htp_shared_mem_info = HtpSharedMemoryAllocator::AssociatedMemoryInfo();
204+
const bool uses_shared_memory = (ort_value_memory_info.device.Type() == htp_shared_mem_info.device.Type() &&
205+
ort_value_memory_info.device.MemType() == htp_shared_mem_info.device.MemType());
204206

205207
if (!uses_shared_memory) {
206208
LOGS(logger, VERBOSE) << "Setting Qnn_Tensor_t clientBuf to ORT tensor memory.";

0 commit comments

Comments
 (0)