@@ -30,93 +30,29 @@ CostCheckResult PostLayoutTransformCostCheck(const api::GraphRef& graph, const a
   return OrtEPCostCheck(graph, node, perm, outputs_leading_to_transpose);
 }
 
-#if defined(USE_CUDA) && ENABLE_CUDA_NHWC_OPS
-// TODO(mtavenrath) generate list from registered kernels using nhwc domain
-const std::unordered_set<std::string_view>& GetCUDALayoutSensitiveOps() {
-  static std::unordered_set<std::string_view> cuda_nhwc_ops = []() {
-    return std::unordered_set<std::string_view>{
-        "BatchNormalization",
-        "Conv",
-        "ConvTranspose",
-        "GlobalMaxPool",
-        "MaxPool",
-        "GlobalAveragePool",
-        "AveragePool",
-        "GridSample",
-        "DepthToSpace",
-        "SpaceToDepth",
-        "LRN"};
-  }();
-  return cuda_nhwc_ops;
-}
-#endif
-
 /// <summary>
 /// Default function for checking if a node should have its layout changed. Allows EP specific adjustments to the
 /// default set of layout sensitive operators if required.
-///
-/// Longer term, if required, the EP API could allow the EP to provide a delegate to plugin EP specific logic so we
-/// don't hardcode it here.
 /// </summary>
+/// <param name="execution_provider">The EP instance.</param>
 /// <param name="node">Node to check</param>
 /// <returns>true if the node should have its layout converted to NHWC.</returns>
-bool ConvertNodeLayout(const api::NodeRef& node) {
+bool ShouldConvertNodeLayoutToNhwc(const IExecutionProvider& execution_provider, const api::NodeRef& node) {
   // skip if op is not an ONNX or contrib op
-  auto domain = node.Domain();
+  const auto domain = node.Domain();
   if (domain != kOnnxDomain && domain != kMSDomain) {
     return false;
   }
 
-  const auto& layout_sensitive_ops = GetORTLayoutSensitiveOps();
-
-  // handle special cases
-#if defined(USE_JSEP)
-  // TODO(fs-eire): Remove special case handing of JSEP once NHWC Resize implementation is fixed
-  if (node.GetExecutionProviderType() == kJsExecutionProvider) {
-    if (node.OpType() == "Resize") {
-      // leave Resize as-is pending bugfix for NHWC implementation. this means the node will remain in the ONNX domain
-      // with the original input layout.
-      return false;
-    }
+  const auto op_type = node.OpType();
+  if (auto should_convert_from_ep = execution_provider.ShouldConvertDataLayoutForOp(domain, op_type, DataLayout::NHWC);
+      should_convert_from_ep.has_value()) {
+    return *should_convert_from_ep;
   }
-#endif
 
-  // NHWC for Resize operator is not implemented on kWebGpuExecutionProvider
-#if defined(USE_WEBGPU)
-  if (node.GetExecutionProviderType() == kWebGpuExecutionProvider) {
-    if (node.OpType() == "Resize") {
-      return false;
-    }
-  }
-#endif
-
-  // TODO: We don't need to check USE_CUDA || USE_CUDA_PROVIDER_INTERFACE in this function because we're already
-  // checking if the node is assigned to the desired EP (e.g., CUDA EP). We should only need to check
-  // ENABLE_CUDA_NHWC_OPS.
-#if (defined(USE_CUDA) || defined(USE_CUDA_PROVIDER_INTERFACE)) && ENABLE_CUDA_NHWC_OPS
-  if (node.GetExecutionProviderType() == kCudaExecutionProvider) {
-    if (layout_sensitive_ops.count(node.OpType())) {
-      const auto& cuda_nhwc_ops = GetCUDALayoutSensitiveOps();
-      if (!cuda_nhwc_ops.count(node.OpType())) {
-        return false;
-      }
-    }
-  }
-#endif
-
-  // TODO: We don't really need EP pre-processor macros in this function because we're already checking if the
-  // node is assigned to the desired EP (e.g., QNN EP). There's nothing about this code that absolutely requires
-  // conditional compilation.
-#if defined(USE_QNN) || defined(USE_QNN_PROVIDER_INTERFACE)
-  if (node.GetExecutionProviderType() == kQnnExecutionProvider) {
-    if (node.OpType() == "Upsample") {
-      // Upsample is translated to QNN's Resize, which requires the NHWC layout for processing.
-      return true;
-    }
-  }
-#endif
-
-  return layout_sensitive_ops.count(node.OpType()) != 0;
+  const auto& layout_sensitive_ops = GetORTLayoutSensitiveOps();
+  const auto op_identifier = MakeORTLayoutSensitiveOpId(domain, op_type);
+  return layout_sensitive_ops.find(op_identifier) != layout_sensitive_ops.end();
 }
 }  // namespace
 
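For reference, the EP-specific special cases removed above are replaced by a single query to the EP: ShouldConvertDataLayoutForOp(domain, op_type, DataLayout::NHWC) returns an optional answer, and an empty optional falls through to the default layout-sensitive op set. Below is a minimal sketch of what an EP-side override might look like, reproducing the old WebGPU "leave Resize in NCHW" behavior; the exact virtual signature is an assumption inferred from the call site in this hunk.

// Sketch only. Assumes ShouldConvertDataLayoutForOp is a virtual member of
// IExecutionProvider taking (domain, op_type, target_layout) and returning
// std::optional<bool>, as the call site above suggests.
std::optional<bool> WebGpuExecutionProvider::ShouldConvertDataLayoutForOp(
    std::string_view node_domain,
    std::string_view node_op_type,
    DataLayout target_data_layout) const {
  if (target_data_layout == DataLayout::NHWC && node_op_type == "Resize") {
    return false;  // NHWC Resize is not implemented; keep the original layout.
  }
  return std::nullopt;  // defer to the default layout-sensitive op set
}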
@@ -126,25 +62,37 @@ bool ConvertNodeLayout(const api::NodeRef& node) {
 // Once all the layout sensitive ops requested by the EP are wrapped the transpose optimizer will attempt to remove
 // as many of the layout transposes as possible.
 const std::unordered_set<std::string_view>& GetORTLayoutSensitiveOps() {
-  static std::unordered_set<std::string_view> ort_layout_sensitive_ops = []() {
-    const auto& layout_sensitive_ops = onnx_transpose_optimization::GetLayoutSensitiveOps();
+  static const std::unordered_set<std::string_view> ort_layout_sensitive_ops = []() {
+    const auto& layout_sensitive_onnx_ops = onnx_transpose_optimization::GetLayoutSensitiveOps();
+
+    // Define a static local string array so we can refer to the elements with string_views.
+    static const std::string layout_sensitive_contrib_ops[]{
+        MakeORTLayoutSensitiveOpId(kMSDomain, "FusedConv"),
+        MakeORTLayoutSensitiveOpId(kMSDomain, "GridSample"),
+        MakeORTLayoutSensitiveOpId(kMSDomain, "QLinearAveragePool"),
+        MakeORTLayoutSensitiveOpId(kMSDomain, "QLinearGlobalAveragePool"),
+    };
+
     std::unordered_set<std::string_view> ort_specific_ops =
         {
-            "FusedConv",
-            "QLinearAveragePool",
-            "QLinearGlobalAveragePool",
             // Whilst the ONNX spec doesn't specify a layout for Resize, we treat it as layout sensitive by default
             // as EPs tend to only support one layout.
             "Resize",
         };
 
-    ort_specific_ops.insert(layout_sensitive_ops.cbegin(), layout_sensitive_ops.cend());
+    ort_specific_ops.insert(std::begin(layout_sensitive_onnx_ops), std::end(layout_sensitive_onnx_ops));
+    ort_specific_ops.insert(std::begin(layout_sensitive_contrib_ops), std::end(layout_sensitive_contrib_ops));
     return ort_specific_ops;
   }();
 
   return ort_layout_sensitive_ops;
 }
 
+// "op_type" if from ONNX domain, "domain:op_type" otherwise.
+std::string MakeORTLayoutSensitiveOpId(std::string_view domain, std::string_view op_type) {
+  return (domain == kOnnxDomain) ? std::string(op_type) : MakeString(domain, ":", op_type);
+}
+
 Status TransformLayoutForEP(Graph& graph, bool& modified, const IExecutionProvider& execution_provider,
                             AllocatorPtr cpu_allocator,
                             const DebugGraphFn& debug_graph_fn) {
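The identifier format documented above determines both what GetORTLayoutSensitiveOps stores and what ShouldConvertNodeLayoutToNhwc must look up. A few illustrative values, assuming kOnnxDomain is the empty string and kMSDomain is "com.microsoft" as elsewhere in ONNX Runtime:

// Illustrative only; the domain constant values are assumptions noted above.
MakeORTLayoutSensitiveOpId(kOnnxDomain, "Conv");              // "Conv"
MakeORTLayoutSensitiveOpId(kMSDomain, "GridSample");          // "com.microsoft:GridSample"
MakeORTLayoutSensitiveOpId(kMSDomain, "QLinearAveragePool");  // "com.microsoft:QLinearAveragePool"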
@@ -159,7 +107,7 @@ Status TransformLayoutForEP(Graph& graph, bool& modified, const IExecutionProvid
       continue;
     }
 
-    if (ConvertNodeLayout(*node)) {
+    if (ShouldConvertNodeLayoutToNhwc(execution_provider, *node)) {
      // domain kMSInternalNHWCDomain uses OpType "Conv" for both Conv and FusedConv.
      // So, change the OpType to "Conv" for FusedConv.
      std::string_view op_type = node->OpType() == "FusedConv" ? "Conv" : node->OpType();
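For context, the hunk above sits inside the per-node loop of TransformLayoutForEP. A rough sketch of that surrounding flow follows; the iteration and wrapping details here are assumptions for illustration, not the exact implementation.

// Rough sketch, not the actual loop body.
for (std::unique_ptr<api::NodeRef>& node : api_graph->Nodes()) {  // api_graph is hypothetical
  if (node->GetExecutionProviderType() != execution_provider.Type()) {
    continue;  // only convert nodes assigned to this EP
  }
  if (ShouldConvertNodeLayoutToNhwc(execution_provider, *node)) {
    std::string_view op_type = node->OpType() == "FusedConv" ? "Conv" : node->OpType();
    // Replace the node with an equivalent in kMSInternalNHWCDomain and wrap its
    // inputs and outputs with Transpose ops; the transpose optimizer later removes
    // as many of those transposes as possible.
  }
}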