File tree Expand file tree Collapse file tree 1 file changed +2
-0
lines changed
onnxscript/rewriter/ort_fusions Expand file tree Collapse file tree 1 file changed +2
-0
lines changed Original file line number Diff line number Diff line change 2929 fuse_rotary_embedding ,
3030)
3131from onnxscript .rewriter .ort_fusions .sdpa import fuse_sdpa
32+ from onnxscript .rewriter .ort_fusions .sdpa_via_mha import replace_sdpa_by_mha
3233from onnxscript .rewriter .ort_fusions .skip_normalization import (
3334 fuse_skip_layer_normalization ,
3435 fuse_skip_rms_normalization ,
@@ -104,6 +105,7 @@ def fuse(func, **kwargs):
104105 fusion_count ["attention" ] = fuse (fuse_attention )
105106 fusion_count ["gelu" ] = fuse (fuse_gelu )
106107 fusion_count ["bias_gelu" ] = fuse (fuse_bias_gelu )
108+ fusion_count ["sdpa_via_mha" ] = fuse (replace_sdpa_by_mha )
107109 # Finally: inline any intermediate fusion functions introduced that were not
108110 # consumed by other fusions, and eliminate any remaining unused nodes.
109111 optimize (model )
You can’t perform that action at this time.
0 commit comments