Skip to content

Commit 6346cdd

Browse files
authored
Add ShouldConvertDataLayoutForOp() API to allow EPs to customize layout sensitive ops (microsoft#25147)
### Description Add `IExecutionProvider::ShouldConvertDataLayoutForOp()` to allow EPs to customize layout sensitive ops. Move existing hardcoded EP-specific logic out of the layout transformer code. Add `OrtEp::ShouldConvertDataLayoutForOp` to the ABI EP API to allow similar customization by plugin EPs. ### Motivation and Context Enable layout sensitive op customization through the internal EP interface and the ABI EP API.
1 parent 6c4f2ff commit 6346cdd

17 files changed

+313
-98
lines changed

include/onnxruntime/core/framework/execution_provider.h

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
#ifndef SHARED_PROVIDER
77
#include <memory>
8+
#include <optional>
9+
#include <string_view>
810
#include <unordered_map>
911
#include <unordered_set>
1012

@@ -63,6 +65,9 @@ using RunOptions = ::OrtRunOptions;
6365
enum class DataLayout {
6466
NCHW,
6567
NHWC,
68+
69+
// NCHW is the default ONNX standard data layout. So default to it.
70+
Default = NCHW,
6671
};
6772

6873
class IExecutionProvider {
@@ -323,9 +328,21 @@ class IExecutionProvider {
323328
}
324329

325330
virtual DataLayout GetPreferredLayout() const {
326-
// NCHW is the default ONNX standard data layout. So default to it.
327331
// EPs which prefer a different layout should override to return their preferred layout.
328-
return DataLayout::NCHW;
332+
return DataLayout::Default;
333+
}
334+
335+
/**
336+
Given an op with domain `domain` and type `op_type`, determine whether an associated node's data layout should be
337+
converted to `target_data_layout`.
338+
If the EP prefers a non-default data layout (see `GetPreferredLayout()`), this function will be called during
339+
layout transformation with `target_data_layout` set to the EP's preferred data layout.
340+
A return value of `std::nullopt` indicates that this decision is left to ORT.
341+
*/
342+
virtual std::optional<bool> ShouldConvertDataLayoutForOp(std::string_view /*domain*/,
343+
std::string_view /*op_type*/,
344+
DataLayout /*target_data_layout*/) const {
345+
return std::nullopt;
329346
}
330347

331348
virtual void RegisterStreamHandlers(IStreamCommandHandleRegistry& /*stream_handle_registry*/, AllocatorMap&) const {}

include/onnxruntime/core/session/onnxruntime_ep_c_api.h

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -299,12 +299,18 @@ struct OrtEpApi {
299299
};
300300

301301
/**
302-
* \brief The data layout type that is preferred by an EP.
302+
* \brief The data layout type.
303+
*
304+
* EPs may specify a preferred data layout type. ORT's default layout type is OrtEpDataLayout_NCHW, or
305+
* OrtEpDataLayout_Default.
306+
*
303307
* \since Version 1.23.
304308
*/
305309
typedef enum OrtEpDataLayout {
306310
OrtEpDataLayout_NCHW = 0,
307311
OrtEpDataLayout_NHWC,
312+
313+
OrtEpDataLayout_Default = OrtEpDataLayout_NCHW,
308314
} OrtEpDataLayout;
309315

310316
/**
@@ -420,6 +426,34 @@ struct OrtEp {
420426
OrtStatus*(ORT_API_CALL* GetPreferredDataLayout)(_In_ OrtEp* this_ptr,
421427
_Out_ OrtEpDataLayout* preferred_data_layout);
422428

429+
/** \brief Given an op with domain `domain` and type `op_type`, determine whether an associated node's data layout
430+
* should be converted to `target_data_layout`.
431+
* If the EP prefers a non-default data layout (see `GetPreferredDataLayout()`), this function will be called
432+
* during layout transformation with `target_data_layout` set to the EP's preferred data layout.
433+
*
434+
* \note Implementation of this function is optional.
435+
* If an EP prefers a non-default data layout, it may implement this to customize the specific op data layout
436+
* preferences at a finer granularity.
437+
*
438+
* \param[in] this_ptr The OrtEp instance.
439+
* \param[in] domain The op domain. An empty string means the ONNX domain.
440+
* \param[in] op_type The op type.
441+
* \param[in] target_data_layout The target data layout.
442+
* \param[out] should_convert Whether the associated node's data layout should be converted to `target_data_layout`.
443+
* If greater than 0, convert.
444+
* If 0, don't convert.
445+
* Otherwise, if less than 0, leave the decision to ORT.
446+
*
447+
* \snippet{doc} snippets.dox OrtStatus Return Value
448+
*
449+
* \since Version 1.23.
450+
*/
451+
OrtStatus*(ORT_API_CALL* ShouldConvertDataLayoutForOp)(_In_ OrtEp* this_ptr,
452+
_In_z_ const char* domain,
453+
_In_z_ const char* op_type,
454+
_In_ OrtEpDataLayout target_data_layout,
455+
_Out_ int* should_convert);
456+
423457
/** \brief Set dynamic options on this EP.
424458
*
425459
* Dynamic options can be set by the user at any time after session creation with `OrtApi::SetEpDynamicOptions()`.

onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc

Lines changed: 29 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -30,93 +30,29 @@ CostCheckResult PostLayoutTransformCostCheck(const api::GraphRef& graph, const a
3030
return OrtEPCostCheck(graph, node, perm, outputs_leading_to_transpose);
3131
}
3232

33-
#if defined(USE_CUDA) && ENABLE_CUDA_NHWC_OPS
34-
// TODO(mtavenrath) generate list from registered kernels using nhwc domain
35-
const std::unordered_set<std::string_view>& GetCUDALayoutSensitiveOps() {
36-
static std::unordered_set<std::string_view> cuda_nhwc_ops = []() {
37-
return std::unordered_set<std::string_view>{
38-
"BatchNormalization",
39-
"Conv",
40-
"ConvTranspose",
41-
"GlobalMaxPool",
42-
"MaxPool",
43-
"GlobalAveragePool",
44-
"AveragePool",
45-
"GridSample",
46-
"DepthToSpace",
47-
"SpaceToDepth",
48-
"LRN"};
49-
}();
50-
return cuda_nhwc_ops;
51-
}
52-
#endif
53-
5433
/// <summary>
5534
/// Default function for checking if a node should have its layout changed. Allows EP specific adjustments to the
5635
/// default set of layout sensitive operators if required.
57-
///
58-
/// Longer term, if required, the EP API could allow the EP to provide a delegate to plugin EP specific logic so we
59-
/// don't hardcode it here.
6036
/// </summary>
37+
/// <param name="execution_provider">The EP instance.</param>
6138
/// <param name="node">Node to check</param>
6239
/// <returns>true if the node should have its layout converted to NHWC.</returns>
63-
bool ConvertNodeLayout(const api::NodeRef& node) {
40+
bool ShouldConvertNodeLayoutToNhwc(const IExecutionProvider& execution_provider, const api::NodeRef& node) {
6441
// skip if op is not an ONNX or contrib op
65-
auto domain = node.Domain();
42+
const auto domain = node.Domain();
6643
if (domain != kOnnxDomain && domain != kMSDomain) {
6744
return false;
6845
}
6946

70-
const auto& layout_sensitive_ops = GetORTLayoutSensitiveOps();
71-
72-
// handle special cases
73-
#if defined(USE_JSEP)
74-
// TODO(fs-eire): Remove special case handling of JSEP once NHWC Resize implementation is fixed
75-
if (node.GetExecutionProviderType() == kJsExecutionProvider) {
76-
if (node.OpType() == "Resize") {
77-
// leave Resize as-is pending bugfix for NHWC implementation. this means the node will remain in the ONNX domain
78-
// with the original input layout.
79-
return false;
80-
}
47+
const auto op_type = node.OpType();
48+
if (auto should_convert_from_ep = execution_provider.ShouldConvertDataLayoutForOp(domain, op_type, DataLayout::NHWC);
49+
should_convert_from_ep.has_value()) {
50+
return *should_convert_from_ep;
8151
}
82-
#endif
8352

84-
// NHWC for Resize operator is not implemented on kWebGpuExecutionProvider
85-
#if defined(USE_WEBGPU)
86-
if (node.GetExecutionProviderType() == kWebGpuExecutionProvider) {
87-
if (node.OpType() == "Resize") {
88-
return false;
89-
}
90-
}
91-
#endif
92-
93-
// TODO: We don't need to check USE_CUDA || USE_CUDA_PROVIDER_INTERFACE in this function because we're already
94-
// checking if the node is assigned to the desired EP (e.g., CUDA EP). We should only need to check
95-
// ENABLE_CUDA_NHWC_OPS.
96-
#if (defined(USE_CUDA) || defined(USE_CUDA_PROVIDER_INTERFACE)) && ENABLE_CUDA_NHWC_OPS
97-
if (node.GetExecutionProviderType() == kCudaExecutionProvider) {
98-
if (layout_sensitive_ops.count(node.OpType())) {
99-
const auto& cuda_nhwc_ops = GetCUDALayoutSensitiveOps();
100-
if (!cuda_nhwc_ops.count(node.OpType())) {
101-
return false;
102-
}
103-
}
104-
}
105-
#endif
106-
107-
// TODO: We don't really need EP pre-processor macros in this function because we're already checking if the
108-
// node is assigned to the desired EP (e.g., QNN EP). There's nothing about this code that absolutely requires
109-
// conditional compilation.
110-
#if defined(USE_QNN) || defined(USE_QNN_PROVIDER_INTERFACE)
111-
if (node.GetExecutionProviderType() == kQnnExecutionProvider) {
112-
if (node.OpType() == "Upsample") {
113-
// Upsample is translated to QNN's Resize, which requires the NHWC layout for processing.
114-
return true;
115-
}
116-
}
117-
#endif
118-
119-
return layout_sensitive_ops.count(node.OpType()) != 0;
53+
const auto& layout_sensitive_ops = GetORTLayoutSensitiveOps();
54+
const auto op_identifier = MakeORTLayoutSensitiveOpId(domain, op_type);
55+
return layout_sensitive_ops.find(op_identifier) != layout_sensitive_ops.end();
12056
}
12157
} // namespace
12258

@@ -126,25 +62,37 @@ bool ConvertNodeLayout(const api::NodeRef& node) {
12662
// Once all the layout sensitive ops requested by the EP are wrapped the transpose optimizer will attempt to remove
12763
// as many of the layout transposes as possible.
12864
const std::unordered_set<std::string_view>& GetORTLayoutSensitiveOps() {
129-
static std::unordered_set<std::string_view> ort_layout_sensitive_ops = []() {
130-
const auto& layout_sensitive_ops = onnx_transpose_optimization::GetLayoutSensitiveOps();
65+
static const std::unordered_set<std::string_view> ort_layout_sensitive_ops = []() {
66+
const auto& layout_sensitive_onnx_ops = onnx_transpose_optimization::GetLayoutSensitiveOps();
67+
68+
// Define a static local string array so we can refer to the elements with string_views.
69+
static const std::string layout_sensitive_contrib_ops[]{
70+
MakeORTLayoutSensitiveOpId(kMSDomain, "FusedConv"),
71+
MakeORTLayoutSensitiveOpId(kMSDomain, "GridSample"),
72+
MakeORTLayoutSensitiveOpId(kMSDomain, "QLinearAveragePool"),
73+
MakeORTLayoutSensitiveOpId(kMSDomain, "QLinearGlobalAveragePool"),
74+
};
75+
13176
std::unordered_set<std::string_view> ort_specific_ops =
13277
{
133-
"FusedConv",
134-
"QLinearAveragePool",
135-
"QLinearGlobalAveragePool",
13678
// Whilst the ONNX spec doesn't specify a layout for Resize, we treat it as layout sensitive by default
13779
// as EPs tend to only support one layout.
13880
"Resize",
13981
};
14082

141-
ort_specific_ops.insert(layout_sensitive_ops.cbegin(), layout_sensitive_ops.cend());
83+
ort_specific_ops.insert(std::begin(layout_sensitive_onnx_ops), std::end(layout_sensitive_onnx_ops));
84+
ort_specific_ops.insert(std::begin(layout_sensitive_contrib_ops), std::end(layout_sensitive_contrib_ops));
14285
return ort_specific_ops;
14386
}();
14487

14588
return ort_layout_sensitive_ops;
14689
}
14790

91+
// Returns "op_type" if the op is from the ONNX domain, "domain:op_type" otherwise.
92+
std::string MakeORTLayoutSensitiveOpId(std::string_view domain, std::string_view op_type) {
93+
return (domain == kOnnxDomain) ? std::string(op_type) : MakeString(domain, ":", op_type);
94+
}
95+
14896
Status TransformLayoutForEP(Graph& graph, bool& modified, const IExecutionProvider& execution_provider,
14997
AllocatorPtr cpu_allocator,
15098
const DebugGraphFn& debug_graph_fn) {
@@ -159,7 +107,7 @@ Status TransformLayoutForEP(Graph& graph, bool& modified, const IExecutionProvid
159107
continue;
160108
}
161109

162-
if (ConvertNodeLayout(*node)) {
110+
if (ShouldConvertNodeLayoutToNhwc(execution_provider, *node)) {
163111
// domain kMSInternalNHWCDomain uses OpType "Conv" for both Conv and FusedConv.
164112
// So, change the OpType to "Conv" for FusedConv.
165113
std::string_view op_type = node->OpType() == "FusedConv" ? "Conv" : node->OpType();

onnxruntime/core/optimizer/layout_transformation/layout_transformation.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,10 +68,19 @@ bool IsSupportedOpset(const Graph& graph);
6868
/// Gets a list of layout sensitive ops for ORT. This list contains ONNX standard defined
6969
/// layout sensitive ops + contrib ops + ops which are not layout sensitive but are treated as
7070
/// layout sensitive by ORT EPs (example Resize).
71+
///
72+
/// Note: The format of the returned op identifiers is "<op type>" for ops in the ONNX domain and
73+
/// "<domain>:<op type>" for ops in other domains. `MakeORTLayoutSensitiveOpId()` can be used to
74+
/// create an op identifier with this format.
7175
/// </summary>
72-
/// <returns>unordered set of op_types which are layout sensitive</returns>
76+
/// <returns>set of op identifiers which are layout sensitive</returns>
7377
const std::unordered_set<std::string_view>& GetORTLayoutSensitiveOps();
7478

79+
/// <summary>
80+
/// Creates an op identifier compatible with `GetORTLayoutSensitiveOps()`.
81+
/// </summary>
82+
std::string MakeORTLayoutSensitiveOpId(std::string_view domain, std::string_view op_type);
83+
7584
/// <summary>
7685
/// Inserts transposes around op inputs/outputs. Alternatively transposes initializers or uses existing Transpose
7786
/// nodes if possible. Populates shape information on affected node inputs/outputs to reflect the change.

onnxruntime/core/providers/cuda/cuda_execution_provider.cc

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,37 @@ DataLayout CUDAExecutionProvider::GetPreferredLayout() const {
323323
return this->IsNHWCPreferred() ? DataLayout::NHWC : DataLayout::NCHW;
324324
}
325325

326+
std::optional<bool> CUDAExecutionProvider::ShouldConvertDataLayoutForOp(std::string_view node_domain,
327+
std::string_view node_op_type,
328+
DataLayout target_data_layout) const {
329+
#if defined(ENABLE_CUDA_NHWC_OPS)
330+
if (target_data_layout != DataLayout::NHWC) {
331+
return std::nullopt;
332+
}
333+
334+
// TODO(mtavenrath) generate list from registered kernels using nhwc domain
335+
static const std::unordered_set<std::string_view> cuda_nhwc_onnx_ops{
336+
"BatchNormalization",
337+
"Conv",
338+
"ConvTranspose",
339+
"GlobalMaxPool",
340+
"MaxPool",
341+
"GlobalAveragePool",
342+
"AveragePool",
343+
"GridSample",
344+
"DepthToSpace",
345+
"SpaceToDepth",
346+
"LRN",
347+
};
348+
349+
return (node_domain == kOnnxDomain && cuda_nhwc_onnx_ops.find(node_op_type) != cuda_nhwc_onnx_ops.end()) ||
350+
(node_domain == kMSDomain && node_op_type == "GridSample");
351+
352+
#else // defined(ENABLE_CUDA_NHWC_OPS)
353+
return std::nullopt;
354+
#endif
355+
}
356+
326357
CUDAExecutionProvider::~CUDAExecutionProvider() {
327358
// clean up thread local context caches
328359
{

onnxruntime/core/providers/cuda/cuda_execution_provider.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@ class CUDAExecutionProvider : public IExecutionProvider {
3939

4040
DataLayout GetPreferredLayout() const override;
4141

42+
std::optional<bool> ShouldConvertDataLayoutForOp(std::string_view node_domain,
43+
std::string_view node_op_type,
44+
DataLayout target_data_layout) const override;
45+
4246
const void* GetExecutionHandle() const noexcept override {
4347
// The CUDA interface does not return anything interesting.
4448
return nullptr;

onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@
2424

2525
namespace onnxruntime::cuda {
2626

27-
// When adding new supported NHWC operations make sure to also integrate them into: ConvertNodeLayout
28-
// in onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc
27+
// When adding new supported NHWC operations make sure to also integrate them into
28+
// CUDAExecutionProvider::ShouldConvertDataLayoutForOp()
2929

3030
class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(7, 8, float, BatchNormalization);
3131
class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(7, 8, double, BatchNormalization);

onnxruntime/core/providers/js/js_execution_provider.cc

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -849,6 +849,23 @@ std::unique_ptr<onnxruntime::IExternalDataLoader> JsExecutionProvider::GetExtern
849849
return std::make_unique<js::ExternalDataLoader>();
850850
}
851851

852+
std::optional<bool> JsExecutionProvider::ShouldConvertDataLayoutForOp(std::string_view node_domain,
853+
std::string_view node_op_type,
854+
DataLayout target_data_layout) const {
855+
if (target_data_layout != DataLayout::NHWC) {
856+
return std::nullopt;
857+
}
858+
859+
// TODO(fs-eire): Remove special case handling of JSEP once NHWC Resize implementation is fixed
860+
if (node_domain == kOnnxDomain && node_op_type == "Resize") {
861+
// leave Resize as-is pending bugfix for NHWC implementation. this means the node will remain in the ONNX domain
862+
// with the original input layout.
863+
return false;
864+
}
865+
866+
return std::nullopt;
867+
}
868+
852869
JsExecutionProvider::~JsExecutionProvider() {
853870
}
854871

onnxruntime/core/providers/js/js_execution_provider.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,10 @@ class JsExecutionProvider : public IExecutionProvider {
5454

5555
DataLayout GetPreferredLayout() const override { return preferred_data_layout_; }
5656

57+
std::optional<bool> ShouldConvertDataLayoutForOp(std::string_view node_domain,
58+
std::string_view node_op_type,
59+
DataLayout target_data_layout) const override;
60+
5761
FusionStyle GetFusionStyle() const override { return FusionStyle::FilteredGraphViewer; }
5862

5963
// JSEP disallow concurrent run because actual implementation (eg. WebGPU backend) relies on global states to work,

onnxruntime/core/providers/qnn/qnn_execution_provider.cc

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1066,6 +1066,21 @@ DataLayout QNNExecutionProvider::GetPreferredLayout() const {
10661066
return DataLayout::NHWC;
10671067
}
10681068

1069+
std::optional<bool> QNNExecutionProvider::ShouldConvertDataLayoutForOp(std::string_view node_domain,
1070+
std::string_view node_op_type,
1071+
DataLayout target_data_layout) const {
1072+
if (target_data_layout != DataLayout::NHWC) {
1073+
return std::nullopt;
1074+
}
1075+
1076+
if (node_domain == kOnnxDomain && node_op_type == "Upsample") {
1077+
// Upsample is translated to QNN's Resize, which requires the NHWC layout for processing.
1078+
return true;
1079+
}
1080+
1081+
return std::nullopt;
1082+
}
1083+
10691084
Status QNNExecutionProvider::CreateComputeFunc(std::vector<NodeComputeInfo>& node_compute_funcs,
10701085
const logging::Logger& logger) {
10711086
NodeComputeInfo compute_info;

0 commit comments

Comments
 (0)