Skip to content

Commit 45ffd99

Browse files
Compile API: disable optimizations by default (microsoft#25474)
### Description - Disables graph optimizations by default when using the explicit compiling API. - Adds `ModelCompilationOptions_SetGraphOptimizationLevel` to allow the user to set an optimization level. - Adds C++, Python, and C# bindings for the new API function. - Updates `ModelCompilationOptions_SetFlags` to take in a `uint32_t flags` parameter instead of `size_t flags` to ensure the same size across platforms. This API is not yet in a public ORT release, so safe to modify. ### Motivation and Context When compiling, prefer allowing the EP to do the optimizations instead of ORT.
1 parent 5537d33 commit 45ffd99

File tree

16 files changed

+222
-12
lines changed

16 files changed

+222
-12
lines changed

csharp/src/Microsoft.ML.OnnxRuntime/CompileModel.shared.cs

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ public class OrtModelCompilationOptions : SafeHandle
2727
/// <summary>
2828
/// Create a new OrtModelCompilationOptions object from SessionOptions.
2929
/// </summary>
30+
/// <remarks>By default, the GraphOptimizationLevel is set to ORT_DISABLE_ALL. Use SetGraphOptimizationLevel()
31+
/// to enable graph optimizations.</remarks>
3032
/// <param name="sessionOptions">SessionOptions instance to read settings from.</param>
3133
public OrtModelCompilationOptions(SessionOptions sessionOptions)
3234
: base(IntPtr.Zero, true)
@@ -130,6 +132,33 @@ public void SetFlags(OrtCompileApiFlags flags)
130132
NativeMethods.CompileApi.OrtModelCompilationOptions_SetFlags(handle, (uint)flags));
131133
}
132134

135+
/// <summary>
136+
/// Sets information related to EP context binary file. The Ep uses this information to decide the
137+
/// location and context binary file name when compiling with both the input and output models
138+
/// stored in buffers.
139+
/// </summary>
140+
/// <param name="outputDirectory">Path to the model directory.</param>
141+
/// <param name="modelName">The name of the model.</param>
142+
public void SetEpContextBinaryInformation(string outputDirectory, string modelName)
143+
{
144+
var platformOutputDirectory = NativeOnnxValueHelper.GetPlatformSerializedString(outputDirectory);
145+
var platformModelName = NativeOnnxValueHelper.GetPlatformSerializedString(modelName);
146+
NativeApiStatus.VerifySuccess(
147+
NativeMethods.CompileApi.OrtModelCompilationOptions_SetEpContextBinaryInformation(
148+
handle, platformOutputDirectory, platformModelName));
149+
}
150+
151+
/// <summary>
152+
/// Sets the graph optimization level. Defaults to ORT_DISABLE_ALL if not specified.
153+
/// </summary>
154+
/// <param name="graphOptimizationLevel">The graph optimization level to set.</param>
155+
public void SetGraphOptimizationLevel(GraphOptimizationLevel graphOptimizationLevel)
156+
{
157+
NativeApiStatus.VerifySuccess(
158+
NativeMethods.CompileApi.OrtModelCompilationOptions_SetGraphOptimizationLevel(
159+
handle, graphOptimizationLevel));
160+
}
161+
133162
internal IntPtr Handle => handle;
134163

135164

csharp/src/Microsoft.ML.OnnxRuntime/NativeCompileApiMethods.shared.cs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ public struct OrtCompileApi
2121
public IntPtr ModelCompilationOptions_SetEpContextEmbedMode;
2222
public IntPtr CompileModel;
2323
public IntPtr ModelCompilationOptions_SetFlags;
24+
public IntPtr ModelCompilationOptions_SetEpContextBinaryInformation;
25+
public IntPtr ModelCompilationOptions_SetGraphOptimizationLevel;
2426
}
2527

2628
internal class NativeMethods
@@ -101,6 +103,21 @@ public DOrtModelCompilationOptions_SetOutputModelExternalInitializersFile
101103
uint flags);
102104
public DOrtModelCompilationOptions_SetFlags OrtModelCompilationOptions_SetFlags;
103105

106+
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
107+
public delegate IntPtr /* OrtStatus* */ DOrtModelCompilationOptions_SetEpContextBinaryInformation(
108+
IntPtr /* OrtModelCompilationOptions* */ options,
109+
byte[] /* const ORTCHAR_T* */ outputDirectory,
110+
byte[] /* const ORTCHAR_T* */ modelName);
111+
public DOrtModelCompilationOptions_SetEpContextBinaryInformation
112+
OrtModelCompilationOptions_SetEpContextBinaryInformation;
113+
114+
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
115+
public delegate IntPtr /* OrtStatus* */ DOrtModelCompilationOptions_SetGraphOptimizationLevel(
116+
IntPtr /* OrtModelCompilationOptions* */ options,
117+
GraphOptimizationLevel graphOptimizationLevel);
118+
public DOrtModelCompilationOptions_SetGraphOptimizationLevel
119+
OrtModelCompilationOptions_SetGraphOptimizationLevel;
120+
104121
internal NativeMethods(OnnxRuntime.NativeMethods.DOrtGetCompileApi getCompileApi)
105122
{
106123

@@ -161,6 +178,16 @@ internal NativeMethods(OnnxRuntime.NativeMethods.DOrtGetCompileApi getCompileApi
161178
_compileApi.ModelCompilationOptions_SetFlags,
162179
typeof(DOrtModelCompilationOptions_SetFlags));
163180

181+
OrtModelCompilationOptions_SetEpContextBinaryInformation =
182+
(DOrtModelCompilationOptions_SetEpContextBinaryInformation)Marshal.GetDelegateForFunctionPointer(
183+
_compileApi.ModelCompilationOptions_SetEpContextBinaryInformation,
184+
typeof(DOrtModelCompilationOptions_SetEpContextBinaryInformation));
185+
186+
OrtModelCompilationOptions_SetGraphOptimizationLevel =
187+
(DOrtModelCompilationOptions_SetGraphOptimizationLevel)Marshal.GetDelegateForFunctionPointer(
188+
_compileApi.ModelCompilationOptions_SetGraphOptimizationLevel,
189+
typeof(DOrtModelCompilationOptions_SetGraphOptimizationLevel));
190+
164191
}
165192
}
166193
}

csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/CompileApiTests.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ public void BasicUsage()
3030

3131
compileOptions.SetOutputModelExternalInitializersFile("external_data.bin", 512);
3232
compileOptions.SetEpContextEmbedMode(true);
33+
compileOptions.SetGraphOptimizationLevel(GraphOptimizationLevel.ORT_ENABLE_BASIC);
3334

3435
}
3536

@@ -45,6 +46,7 @@ public void BasicUsage()
4546
UIntPtr bytesSize = new UIntPtr();
4647
var allocator = OrtAllocator.DefaultInstance;
4748
compileOptions.SetOutputModelBuffer(allocator, ref bytePtr, ref bytesSize);
49+
compileOptions.SetEpContextBinaryInformation("./", "squeezenet.onnx");
4850

4951
compileOptions.CompileModel();
5052

include/onnxruntime/core/session/onnxruntime_c_api.h

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7074,6 +7074,9 @@ struct OrtCompileApi {
70747074
* ReleaseOrtModelCompilationsOptions must be called to free the OrtModelCompilationOptions after calling
70757075
* CompileModel.
70767076
*
7077+
* \note By default, the GraphOptimizationLevel is set to ORT_DISABLE_ALL. Use
7078+
* ModelCompilationOptions_SetGraphOptimizationLevel to enable graph optimizations.
7079+
*
70777080
* \param[in] env OrtEnv object.
70787081
* \param[in] session_options The OrtSessionOptions instance from which to create the OrtModelCompilationOptions.
70797082
* \param[out] out The created OrtModelCompilationOptions instance.
@@ -7230,7 +7233,7 @@ struct OrtCompileApi {
72307233
* \since Version 1.23.
72317234
*/
72327235
ORT_API2_STATUS(ModelCompilationOptions_SetFlags, _In_ OrtModelCompilationOptions* model_compile_options,
7233-
size_t flags);
7236+
uint32_t flags);
72347237

72357238
/** Sets information related to EP context binary file.
72367239
*
@@ -7249,6 +7252,19 @@ struct OrtCompileApi {
72497252
_In_ OrtModelCompilationOptions* model_compile_options,
72507253
_In_ const ORTCHAR_T* output_directory,
72517254
_In_ const ORTCHAR_T* model_name);
7255+
7256+
/** Set the graph optimization level.
7257+
*
7258+
* \param[in] model_compile_options The OrtModelCompilationOptions instance.
7259+
* \param[in] graph_optimization_level The graph optimization level.
7260+
*
7261+
* \snippet{doc} snippets.dox OrtStatus Return Value
7262+
*
7263+
* \since Version 1.23.
7264+
*/
7265+
ORT_API2_STATUS(ModelCompilationOptions_SetGraphOptimizationLevel,
7266+
_In_ OrtModelCompilationOptions* model_compile_options,
7267+
_In_ GraphOptimizationLevel graph_optimization_level);
72527268
};
72537269

72547270
/*

include/onnxruntime/core/session/onnxruntime_cxx_api.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1424,7 +1424,9 @@ struct ModelCompilationOptions : detail::Base<OrtModelCompilationOptions> {
14241424
size_t* output_model_buffer_size_ptr); ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelBuffer
14251425
ModelCompilationOptions& SetEpContextBinaryInformation(const ORTCHAR_T* output_directory,
14261426
const ORTCHAR_T* model_name); ///< Wraps OrtApi::ModelCompilationOptions_SetEpContextBinaryInformation
1427-
ModelCompilationOptions& SetFlags(size_t flags); ///< Wraps OrtApi::ModelCompilationOptions_SetFlags
1427+
ModelCompilationOptions& SetFlags(uint32_t flags); ///< Wraps OrtApi::ModelCompilationOptions_SetFlags
1428+
1429+
ModelCompilationOptions& SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level); ///< Wraps OrtApi::ModelCompilationOptions_SetGraphOptimizationLevel
14281430
};
14291431

14301432
/** \brief Compiles an input model to generate a model with EPContext nodes that execute EP-specific kernels. Wraps OrtApi::CompileModels.

include/onnxruntime/core/session/onnxruntime_cxx_inline.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1019,11 +1019,18 @@ inline ModelCompilationOptions& ModelCompilationOptions::SetEpContextEmbedMode(
10191019
return *this;
10201020
}
10211021

1022-
inline ModelCompilationOptions& ModelCompilationOptions::SetFlags(size_t flags) {
1022+
inline ModelCompilationOptions& ModelCompilationOptions::SetFlags(uint32_t flags) {
10231023
Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetFlags(this->p_, flags));
10241024
return *this;
10251025
}
10261026

1027+
inline ModelCompilationOptions& ModelCompilationOptions::SetGraphOptimizationLevel(
1028+
GraphOptimizationLevel graph_optimization_level) {
1029+
Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetGraphOptimizationLevel(this->p_,
1030+
graph_optimization_level));
1031+
return *this;
1032+
}
1033+
10271034
namespace detail {
10281035

10291036
template <typename T>

onnxruntime/core/session/compile_api.cc

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetEpContextEmbedMode
231231
}
232232

233233
ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetFlags,
234-
_In_ OrtModelCompilationOptions* ort_model_compile_options, size_t flags) {
234+
_In_ OrtModelCompilationOptions* ort_model_compile_options, uint32_t flags) {
235235
API_IMPL_BEGIN
236236
#if !defined(ORT_MINIMAL_BUILD)
237237
auto model_compile_options = reinterpret_cast<onnxruntime::ModelCompilationOptions*>(ort_model_compile_options);
@@ -245,6 +245,22 @@ ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetFlags,
245245
API_IMPL_END
246246
}
247247

248+
ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetGraphOptimizationLevel,
249+
_In_ OrtModelCompilationOptions* ort_model_compile_options,
250+
_In_ GraphOptimizationLevel graph_optimization_level) {
251+
API_IMPL_BEGIN
252+
#if !defined(ORT_MINIMAL_BUILD)
253+
auto model_compile_options = reinterpret_cast<onnxruntime::ModelCompilationOptions*>(ort_model_compile_options);
254+
ORT_API_RETURN_IF_STATUS_NOT_OK(model_compile_options->SetGraphOptimizationLevel(graph_optimization_level));
255+
return nullptr;
256+
#else
257+
ORT_UNUSED_PARAMETER(ort_model_compile_options);
258+
ORT_UNUSED_PARAMETER(graph_optimization_level);
259+
return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Compile API is not supported in this build");
260+
#endif // !defined(ORT_MINIMAL_BUILD)
261+
API_IMPL_END
262+
}
263+
248264
ORT_API_STATUS_IMPL(OrtCompileAPI::CompileModel, _In_ const OrtEnv* env,
249265
_In_ const OrtModelCompilationOptions* ort_model_compile_options) {
250266
API_IMPL_BEGIN
@@ -278,6 +294,7 @@ static constexpr OrtCompileApi ort_compile_api = {
278294

279295
&OrtCompileAPI::ModelCompilationOptions_SetFlags,
280296
&OrtCompileAPI::ModelCompilationOptions_SetEpContextBinaryInformation,
297+
&OrtCompileAPI::ModelCompilationOptions_SetGraphOptimizationLevel,
281298
};
282299

283300
// checks that we don't violate the rule that the functions must remain in the slots they were originally assigned

onnxruntime/core/session/compile_api.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,11 @@ ORT_API_STATUS_IMPL(ModelCompilationOptions_SetEpContextEmbedMode, _In_ OrtModel
2929
bool embed_ep_context_in_model);
3030
ORT_API_STATUS_IMPL(CompileModel, _In_ const OrtEnv* env, _In_ const OrtModelCompilationOptions* model_options);
3131
ORT_API_STATUS_IMPL(ModelCompilationOptions_SetFlags, _In_ OrtModelCompilationOptions* model_options,
32-
size_t flags);
32+
uint32_t flags);
3333
ORT_API_STATUS_IMPL(ModelCompilationOptions_SetEpContextBinaryInformation, _In_ OrtModelCompilationOptions* model_compile_options,
3434
_In_ const ORTCHAR_T* output_dir, _In_ const ORTCHAR_T* model_name);
35+
ORT_API_STATUS_IMPL(ModelCompilationOptions_SetGraphOptimizationLevel,
36+
_In_ OrtModelCompilationOptions* model_compile_options,
37+
_In_ GraphOptimizationLevel graph_optimization_level);
3538

3639
} // namespace OrtCompileAPI

onnxruntime/core/session/model_compilation_options.cc

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ ModelCompilationOptions::ModelCompilationOptions(const onnxruntime::Environment&
2727
// Shouldn't fail because the key/value strings are below the maximum string length limits in ConfigOptions.
2828
ORT_ENFORCE(session_options_.value.config_options.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1").IsOK());
2929
ORT_ENFORCE(session_options_.value.config_options.AddConfigEntry(kOrtSessionOptionsDisableModelCompile, "0").IsOK());
30+
31+
session_options_.value.graph_optimization_level = TransformerLevel::Default; // L0: required transformers only
3032
}
3133

3234
void ModelCompilationOptions::SetInputModelPath(const std::string& input_model_path) {
@@ -135,7 +137,7 @@ Status ModelCompilationOptions::SetEpContextEmbedMode(bool embed_ep_context_in_m
135137
return Status::OK();
136138
}
137139

138-
Status ModelCompilationOptions::SetFlags(size_t flags) {
140+
Status ModelCompilationOptions::SetFlags(uint32_t flags) {
139141
EpContextModelGenerationOptions& options = session_options_.value.ep_context_gen_options;
140142
options.error_if_output_file_exists = flags & OrtCompileApiFlags_ERROR_IF_OUTPUT_FILE_EXISTS;
141143
options.action_if_no_compiled_nodes =
@@ -170,6 +172,34 @@ void ModelCompilationOptions::ResetInputModelSettings() {
170172
input_model_data_size_ = 0;
171173
}
172174

175+
Status ModelCompilationOptions::SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level) {
176+
switch (graph_optimization_level) {
177+
case ORT_DISABLE_ALL:
178+
// TransformerLevel::Default means that we only run required transformers.
179+
session_options_.value.graph_optimization_level = onnxruntime::TransformerLevel::Default;
180+
break;
181+
case ORT_ENABLE_BASIC:
182+
session_options_.value.graph_optimization_level = onnxruntime::TransformerLevel::Level1;
183+
break;
184+
case ORT_ENABLE_EXTENDED:
185+
session_options_.value.graph_optimization_level = onnxruntime::TransformerLevel::Level2;
186+
break;
187+
case ORT_ENABLE_LAYOUT:
188+
session_options_.value.graph_optimization_level = onnxruntime::TransformerLevel::Level3;
189+
break;
190+
case ORT_ENABLE_ALL:
191+
session_options_.value.graph_optimization_level = onnxruntime::TransformerLevel::MaxLevel;
192+
break;
193+
default:
194+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "graph_optimization_level with value ",
195+
static_cast<int>(graph_optimization_level), " is invalid. Valid values are: ",
196+
"ORT_DISABLE_ALL (0), ORT_ENABLE_BASIC (1), ORT_ENABLE_EXTENDED (2), ",
197+
"ORT_ENABLE_LAYOUT (3), and ORT_ENABLE_ALL (99).");
198+
}
199+
200+
return Status::OK();
201+
}
202+
173203
Status ModelCompilationOptions::ResetOutputModelSettings() {
174204
EpContextModelGenerationOptions& ep_context_gen_options = session_options_.value.ep_context_gen_options;
175205
ep_context_gen_options.output_model_file_path.clear();

onnxruntime/core/session/model_compilation_options.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ class ModelCompilationOptions {
9595
/// </summary>
9696
/// <param name="flags">unsigned integer set to the bitwise OR of enabled flags.</param>
9797
/// <returns>Status indicating success or an error</returns>
98-
Status SetFlags(size_t flags);
98+
Status SetFlags(uint32_t flags);
9999

100100
/// <summary>
101101
/// Returns a reference to the session options object.
@@ -129,6 +129,13 @@ class ModelCompilationOptions {
129129
/// <returns>input model buffer's size in bytes</returns>
130130
size_t GetInputModelDataSize() const;
131131

132+
/// <summary>
133+
/// Sets the graph optimization level for the underlying session that compiles the model.
134+
/// </summary>
135+
/// <param name="graph_optimization_level">The optimization level</param>
136+
/// <returns></returns>
137+
Status SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level);
138+
132139
/// <summary>
133140
/// Checks if the compilation options described by this object are valid.
134141
/// </summary>

0 commit comments

Comments
 (0)