Commit 340b188
Fusing Initializers with Graph Transforms (microsoft#24726)
### Description

Added a graph transform for mixed-precision graphs when FP16 compute is unavailable. At session creation, this graph transform converts FP16 initializers (_which were changed into FP16-to-FP32 Cast nodes_) to FP32 initializers and fuses them with their next FP32 nodes.

- Behavior before this change: "fp16 initializers -> cast_from_fp16_to_fp32 -> fp32 node/s"
- Behavior after this change: "fp16 initializers converted to fp32 initializers, then fused with fp32 node/s"

### Motivation and Context

This change aims to run FP16 models without the repetitive casting of FP16 initializers to FP32 initializers, by fusing the FP32 initializers with their next nodes when FP16 compute is not available.

> For naming purposes, the newly added graph transform is called "Fused Initializers Graph Transforms" in long form and "FIGT" in short form.

### Working

Currently, the Fuse Initializers graph transform fuses Cast nodes that cast from FP16 to FP32 back into their next/output nodes. Below is an explanation of how this transform works. It depends on `InsertCastTransforms` to produce the intermediate representation, from which it fuses the initializers (that is, Cast nodes with zero inputs, one initializer, and one output) back into their next/output nodes. After fusion, the link/edge between such a Cast node and its next/output node is removed, and the Cast node is removed as well.

```
   "Input Graph"            "Intermediate Representation"               "FIGT Transforms"

      --------             --------    --------    --------                  --------
     | X_Fp16 |           | X_Fp16 |  | W_Fp16 |  | B_Fp16 |                | X_Fp16 |
      --------             --------    --------    --------                  --------
         |                    |           |           |                         |
         |                    V           V           V                         V
         |                 | Cast |    | Cast |    | Cast |                  | Cast |
         |                 | Fp16 |    | Fp16 |    | Fp16 |                  | Fp16 |
         |                 |  To  |    |  To  |    |  To  |                  |  To  |
         |                 | Fp32 |    | Fp32 |    | Fp32 |                  | Fp32 |
         |                    |           |           |                         |
         V                    V           V           V                         V
 -----------------       -----------------------------------          -----------------
|    Conv_Fp16    |      |                                   |        |    Conv_Fp32    |
|   --W_Fp16--    |  ==> |             Conv_Fp32             |  ==>   |   --W_Fp32--    |
|   --B_Fp16--    |      |                                   |        |   --B_Fp32--    |
 -----------------       -----------------------------------          -----------------
         |                                |                                     |
         |                                V                                     |
         |                             | Cast |                                 |
         |                             | Fp32 |                                 |
         |                             |  To  |                                 |
         |                             | Fp16 |                                 |
         |                                |                                     |
         V                                V                                     V
      --------                         --------                              --------
     | Y_Fp16 |                       | Y_Fp16 |                            | Y_Fp16 |
      --------                         --------                              --------
```

The newly added graph transform performs the following actions.

* Detect Cast node/s with a single FP16 initializer converting to FP32.
* Convert all such FP16 initializer/s to FP32 initializer/s.
* Fuse the newly created FP32 initializer/s into the corresponding FP32 node/s.
* Remove the FP16-to-FP32 Cast node/s.

This runs in a loop as follows; the loop excludes Level 1 and partitioning optimizations.

```
Level 2 --> Level 3 --> InsertCastTransforms --> FIGT
   ^                                               |
   |                     "LOOP"                    |
   |                                               |
   -------------------------------------------------
```
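To make the targeted pattern concrete, the sketch below hand-builds the kind of graph shown above: FP16 weight and bias initializers feeding FP16-to-FP32 Cast nodes into an FP32 Conv. In the real pipeline these Cast nodes are produced by `InsertCastTransforms` at session creation; here they are written out explicitly so the pre-fusion shape can be inspected. All tensor names, shapes, and the output filename are illustrative, not taken from this PR.

```python
import numpy as np
import onnx
from onnx import TensorProto, helper, numpy_helper

# FP16 initializers, as they would appear in a model saved for FP16 execution.
w = numpy_helper.from_array(np.ones((1, 1, 3, 3), dtype=np.float16), "W_fp16")
b = numpy_helper.from_array(np.zeros((1,), dtype=np.float16), "B_fp16")

# Cast nodes whose only input is a single FP16 initializer: the exact
# pattern FIGT detects, converts, fuses, and removes.
cast_w = helper.make_node("Cast", ["W_fp16"], ["W_fp32"], to=TensorProto.FLOAT)
cast_b = helper.make_node("Cast", ["B_fp16"], ["B_fp32"], to=TensorProto.FLOAT)
conv = helper.make_node("Conv", ["X_fp32", "W_fp32", "B_fp32"], ["Y_fp32"])

graph = helper.make_graph(
    [cast_w, cast_b, conv],
    "figt_pattern",
    [helper.make_tensor_value_info("X_fp32", TensorProto.FLOAT, [1, 1, 8, 8])],
    [helper.make_tensor_value_info("Y_fp32", TensorProto.FLOAT, [1, 1, 6, 6])],
    initializer=[w, b],
)
onnx.save(helper.make_model(graph), "figt_pattern.onnx")
```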
### Adding FIGT as a Level-4 Graph Transform

This has the following benefits.

1. Ability to turn off any or all of the Level 4 optimizations. We can use the `disable optimizers` functionality to turn off one such optimization during testing, or use the `-o` switch to turn off all Level 4 optimizations when executing a model from the command line or from Python scripts (or any other scripts).
2. The ability to rerun Level 2 and Level 3 optimizations remains intact after Level 4 optimizations are applied. Adding Level 4 ensures that FIGT (or any similar optimization) always runs after `InsertCastTransforms`.
3. It keeps the current graph manipulations untouched and gives us more flexibility to add future optimizations, such as an `Int8 to Int32` or an `FP8 to FP16` up-convert, under Level 4. Level 4 can, as of now, work as a placeholder for any other such upcoming graph optimizations.

```
Level 2 --> Level 3 --> InsertCastTransforms --> Level 4
   ^                                                |
   |                      "LOOP"                    |
   |                                                |
   --------------------------------------------------
```

> Added a placeholder for Level 4 in the graph transforms utils under orttraining. This helps resolve any exceptions that may be encountered during training sessions.

#### Re-running Level 2+ optimizations after Level 4 / FIGT

The idea behind re-running Level 2+ graph transforms is that, after the fusion of initializers with their respective nodes, the nodes are in a format that might be supported by other graph transforms that were previously skipped. Hence, some transformations that previously could not be applied are now valid and can be applied to create a more optimal graph for execution.

### Added a new session option "kOrtSessionOptionsGraphOptimizationsLoopLevel" to handle the graph optimization loop

* When set to 2 or above, it loops until no more optimizations are applied at any level, starting from Level 2 and above.

```
Level 2 --> Level 3 --> InsertCastTransforms --> Level 4
   ^                                                |
   |                      "Loop"                    |
   |                                                |
   --------------------------------------------------
```

* When set to 1 (default), it loops until no more optimizations are applied at Level 4 only.

```
Level 2 --> Level 3 --> InsertCastTransforms --> Level 4
   ^                                                |
   |        "Loop only depending on Level 4"        |
   |                                                |
   --------------------------------------------------
```

* When set to 0, the loop is disabled.

```
Level 2 --> Level 3 --> InsertCastTransforms --> Level 4
   ^                                                |
   |                    "No Loop"                   |
   |                                                |
   X                   xxxxxxxxxxx                  X
```

### Documentation

We have not added any details related to Level 4 to the [Graph Optimizations in ONNX Runtime](https://onnxruntime.ai/docs/performance/model-optimizations/graph-optimizations.html) documentation.

### OLD PR

This PR was created following a thorough discussion on the [OLD PR](microsoft#24175).

Signed-off-by: Sunny Shukla <[email protected]>
1 parent 53ee6c5 commit 340b188

File tree

48 files changed: +1215 −81 lines changed


csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs

Lines changed: 1 addition & 0 deletions
```diff
@@ -17,6 +17,7 @@ public enum GraphOptimizationLevel
         ORT_DISABLE_ALL = 0,
         ORT_ENABLE_BASIC = 1,
         ORT_ENABLE_EXTENDED = 2,
+        ORT_ENABLE_LAYOUT = 3,
         ORT_ENABLE_ALL = 99
     }
```

include/onnxruntime/core/optimizer/graph_transformer_level.h

Lines changed: 2 additions & 1 deletion
```diff
@@ -12,8 +12,9 @@ enum class TransformerLevel : int {
   Level1,  // basic optimizations
   Level2,  // extended optimizations
   Level3,  // layout optimizations
+  Level4,  // unsupported datatypes optimizations
   // The max level should always be same as the last level.
-  MaxLevel = Level3
+  MaxLevel = Level4
 };
 
 }  // namespace onnxruntime
```

include/onnxruntime/core/session/onnxruntime_c_api.h

Lines changed: 1 addition & 0 deletions
```diff
@@ -355,6 +355,7 @@ typedef enum GraphOptimizationLevel {
   ORT_DISABLE_ALL = 0,
   ORT_ENABLE_BASIC = 1,
   ORT_ENABLE_EXTENDED = 2,
+  ORT_ENABLE_LAYOUT = 3,
   ORT_ENABLE_ALL = 99
 } GraphOptimizationLevel;
```
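Across the bindings touched by this commit, the new value consistently slots between the extended and all levels as numeric value 3. To observe what a given level actually rewrites, one can dump the optimized model per level. A minimal Python sketch, using only APIs that predate this commit (this diff does not show whether the Python binding exposes the new layout enum member, so it is avoided here); `figt_pattern.onnx` is the illustrative model from the earlier sketch.

```python
import onnxruntime as ort

# Serialize the optimized graph at each level; diffing the dumps shows
# which rewrites each level adds. ORT_ENABLE_ALL (99) also runs the
# layout-level transforms and, with this commit, the Level-4 loop.
for name, level in [
    ("basic", ort.GraphOptimizationLevel.ORT_ENABLE_BASIC),
    ("extended", ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED),
    ("all", ort.GraphOptimizationLevel.ORT_ENABLE_ALL),
]:
    so = ort.SessionOptions()
    so.graph_optimization_level = level
    so.optimized_model_filepath = f"model_opt_{name}.onnx"
    ort.InferenceSession("figt_pattern.onnx", so,
                         providers=["CPUExecutionProvider"])
```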

include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h

Lines changed: 31 additions & 0 deletions
```diff
@@ -107,6 +107,37 @@ static const char* const kOrtSessionOptionsMemoryOptimizerProbeConfig = "optimiz
 // Default is an empty string which means no optimizers are disabled.
 static const char* const kOrtSessionOptionsDisableSpecifiedOptimizers = "optimization.disable_specified_optimizers";
 
+// It controls whether to run graph optimizations in loop or not.
+//
+// "0": disable. Graph Optimization Loop is disabled.
+// ```
+//   Level 2 --> Level 3 --> InsertCastTransforms --> Level 4
+//      ^                                                |
+//      |                    "No Loop"                   |
+//      |                                                |
+//      X                   xxxxxxxxxxx                  X
+// ```
+// "1": enable. Graph Optimization Loop is enabled, such that, if optimizations at Level 4 are applied then
+//      the loop will check for any other valid optimization that can happen.
+// ```
+//   Level 2 --> Level 3 --> InsertCastTransforms --> Level 4
+//      ^                                                |
+//      |        "Loop only depending on Level 4"        |
+//      |                                                |
+//      --------------------------------------------------
+// ```
+// "2": enable. Graph Optimization Loop is enabled, such that, if optimizations at Level 2 or above are applied then
+//      the loop will check for any other valid optimization that can happen.
+// ```
+//   Level 2 --> Level 3 --> InsertCastTransforms --> Level 4
+//      ^                                                |
+//      |                      "Loop"                    |
+//      |                                                |
+//      --------------------------------------------------
+// ```
+// Default value is set to "1".
+static const char* const kOrtSessionOptionsGraphOptimizationsLoopLevel = "session.graph_optimizations_loop_level";
+
 // Enable or disable using device allocator for allocating initialized tensor memory. "1": enable; "0": disable. The default is "0".
 // Using device allocators means the memory allocation is made using malloc/new.
 static const char* const kOrtSessionOptionsUseDeviceAllocatorForInitializers = "session.use_device_allocator_for_initializers";
```
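Since the new key is a plain string config entry rather than a dedicated API, it can be set through the existing generic session-config mechanism in any binding that exposes one. A minimal Python sketch, assuming the illustrative model file from the earlier sketch; the key string is copied verbatim from this header.

```python
import onnxruntime as ort

so = ort.SessionOptions()
so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
# "0": no optimization loop; "1" (default): loop again only when a Level 4
# transform applied; "2": loop while any Level 2+ transform still applies.
so.add_session_config_entry("session.graph_optimizations_loop_level", "2")
sess = ort.InferenceSession("figt_pattern.onnx", so,
                            providers=["CPUExecutionProvider"])
```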

java/src/main/java/ai/onnxruntime/OrtSession.java

Lines changed: 2 additions & 0 deletions
```diff
@@ -652,6 +652,8 @@ public enum OptLevel {
      * graph.
      */
     EXTENDED_OPT(2),
+    /** Applies all the layout optimizations like NCHW and NCHWC to the ONNX graph. */
+    LAYOUT_OPT(3),
     /** Applies all available optimizations to the ONNX graph. */
     ALL_OPT(99);
```

java/src/main/native/OrtJniUtil.c

Lines changed: 2 additions & 0 deletions
```diff
@@ -47,6 +47,8 @@ GraphOptimizationLevel convertOptimizationLevel(jint level) {
       return ORT_ENABLE_BASIC;
     case 2:
       return ORT_ENABLE_EXTENDED;
+    case 3:
+      return ORT_ENABLE_LAYOUT;
     case 99:
       return ORT_ENABLE_ALL;
     default:
```

js/common/lib/inference-session.ts

Lines changed: 1 addition & 1 deletion
```diff
@@ -81,7 +81,7 @@ export declare namespace InferenceSession {
      *
      * This setting is available only in ONNXRuntime (Node.js binding and react-native) or WebAssembly backend
      */
-    graphOptimizationLevel?: 'disabled' | 'basic' | 'extended' | 'all';
+    graphOptimizationLevel?: 'disabled' | 'basic' | 'extended' | 'layout' | 'all';
 
     /**
      * Whether enable CPU memory arena.
```

js/node/src/session_options_helper.cc

Lines changed: 1 addition & 0 deletions
```diff
@@ -31,6 +31,7 @@ const std::unordered_map<std::string, GraphOptimizationLevel> GRAPH_OPT_LEVEL_NA
     {"disabled", ORT_DISABLE_ALL},
     {"basic", ORT_ENABLE_BASIC},
     {"extended", ORT_ENABLE_EXTENDED},
+    {"layout", ORT_ENABLE_LAYOUT},
     {"all", ORT_ENABLE_ALL}};
 
 const std::unordered_map<std::string, ExecutionMode> EXECUTION_MODE_NAME_TO_ID_MAP = {{"sequential", ORT_SEQUENTIAL},
```

js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeModule.java

Lines changed: 1 addition & 0 deletions
```diff
@@ -326,6 +326,7 @@ public WritableMap run(String key, ReadableMap input, ReadableArray output, Read
                 {"disabled", SessionOptions.OptLevel.NO_OPT},
                 {"basic", SessionOptions.OptLevel.BASIC_OPT},
                 {"extended", SessionOptions.OptLevel.EXTENDED_OPT},
+                {"layout", SessionOptions.OptLevel.LAYOUT_OPT},
                 {"all", SessionOptions.OptLevel.ALL_OPT},
               })
           .collect(Collectors.toMap(p -> (String)p[0], p -> (SessionOptions.OptLevel)p[1]));
```

js/react_native/ios/OnnxruntimeModule.mm

Lines changed: 1 addition & 0 deletions
```diff
@@ -301,6 +301,7 @@ - (NSDictionary*)run:(NSString*)url
     @"disabled" : @(ORT_DISABLE_ALL),
     @"basic" : @(ORT_ENABLE_BASIC),
     @"extended" : @(ORT_ENABLE_EXTENDED),
+    @"layout" : @(ORT_ENABLE_LAYOUT),
     @"all" : @(ORT_ENABLE_ALL)
   };
```
306307
