### Description
<!-- Describe your changes. -->
A chain of Cast nodes may cast a tensor back to its original type, which
makes part of the chain redundant. This fusion removes the extra nodes.
E.g.
`A ('float32') -> Cast (to='float16') -> Cast (to='int4') -> Cast (to='float32') -> Cast (to='float16') -> B`
will reduce to
`A ('float32') -> Cast (to='float16') -> B`
All the Cast nodes throughout the path need to have one input and one
output to be considered for the fusion.
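
The rule above can be sketched on a toy model. This is only an illustration, not the code in this PR: the function name `EliminateRedundantCasts` and the representation of a chain as a list of `to` type names are hypothetical, and the sketch assumes every Cast in the chain already satisfies the one-input, one-output condition.

```cpp
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical toy model, not the ONNX Runtime graph API: a linear chain of
// Cast nodes is represented only by the 'to' type of each node, in order.
// Assumes every node in the chain has exactly one input and one output.
//
// Returns the Cast nodes that remain after dropping the longest leading run
// of the chain whose final Cast converts back to the chain's input type.
std::vector<std::string> EliminateRedundantCasts(
    const std::string& input_type,
    const std::vector<std::string>& cast_to) {
  std::size_t drop = 0;  // number of leading Cast nodes that can be removed
  for (std::size_t i = 0; i < cast_to.size(); ++i) {
    if (cast_to[i] == input_type) {
      // Casts 0..i end at the original type, so this prefix is removable.
      // Keep scanning to find the longest such prefix.
      drop = i + 1;
    }
  }
  return std::vector<std::string>(
      cast_to.begin() + static_cast<std::ptrdiff_t>(drop), cast_to.end());
}
```

For the example in this description, an input type of `float32` with the chain `float16, int4, float32, float16` drops the first three Casts and keeps only the final `Cast (to='float16')`.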
### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
Gemma3 ONNX models used to have double casting, and many new models
created by the model builder might have it as well. Extra Casts might
reduce accuracy and increase inference time.
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/optimizer/rewrite_rule.h"

namespace onnxruntime {

/**
@Class CastElimination
The transform that will try to find the longest chain of the type Cast where the 'to' attribute has the same data type as the input of the first Cast node in the chain.
E.g.
A ('float32') -> Cast (to='float16') -> Cast (to='int4') -> Cast (to='float32') -> Cast (to='float16') -> B
will reduce to
A ('float32') -> Cast (to='float16') -> B

All the Cast nodes throughout the path need to have one input and one output to be considered for the fusion.