Skip to content

Commit 330c587

Browse files
authored
[BACKPORT] Allow softmax type conversion to happen before or after elementwise (#1920)
Allow softmax type conversion to happen before or after elementwise ops in attention (#1911) For the attention kernel, migraphx can decide to do softmax in higher precision. It would insert necessary convert operation to cast results back and forth around softmax ops. But these casts could appear before or after any elementwise scaling ops that happen before softmax. Therefore this PR extends pattern matcher to allow such cases. For example before, it only handled cases like : gemm1 -> elementwise ops -> convert(to_higher_precision) -> softmax -> convert(back_to_lower) -> gemm2 With this PR, it can also handle case where it is like following : gemm1 -> convert(to_higher_precision) -> elementwise ops -> softmax -> convert(back_to_lower) -> gemm2
1 parent 221ad7a commit 330c587

File tree

51 files changed

+1338
-297
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

51 files changed

+1338
-297
lines changed

mlir/include/mlir/Dialect/MIGraphX/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ namespace migraphx {
2222

2323
#define GEN_PASS_DECL_MIGRAPHXREALIZEINT4PASS
2424
#define GEN_PASS_DECL_MIGRAPHXTRANSFORMPASS
25+
#define GEN_PASS_DECL_MIGRAPHXTOSASIMPLIFYPASS
2526
#define GEN_PASS_REGISTRATION
2627
#include "mlir/Dialect/MIGraphX/Passes.h.inc"
2728

mlir/include/mlir/Dialect/MIGraphX/Passes.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,4 +41,13 @@ def MIGraphXRealizeInt4Pass : Pass<"migraphx-realize-int4", "func::FuncOp"> {
4141
}];
4242
}
4343

44+
def MIGraphXTosaSimplifyPass : Pass<"migraphx-tosa-simplify", "func::FuncOp"> {
45+
let summary = "Simplify TOSA operations after MIGraphX to TOSA conversion";
46+
let description = [{
47+
This pass simplifies TOSA operations after converting migraphx ops to tosa, optimizing
48+
the representation of operations for better performance and clarity.
49+
}];
50+
let dependentDialects = ["tosa::TosaDialect", "func::FuncDialect"];
51+
}
52+
4453
#endif // MLIR_DIALECT_MIGRAPHX_PASSES

0 commit comments

Comments
 (0)