Skip to content

Commit bde336e

Browse files
authored
Update tosa.mul op in StableHLO-to-TOSA Pass in compliance with MLIR upstream and TOSA-v1.0 (#2702)
This PR updates the StableHLO to TOSA legalization pass to align with the recent changes in the MLIR upstream. Specifically, [the `tosa::MulOp` operation has being modified to comply with the TOSA-v1.0 specification](llvm/llvm-project#121953). The shift parameter of the MUL operator, which was previously an attribute, has been moved to an SSA operand.  The upstream changes caused a compilation failure in the StableHLO to TOSA conversion pass because the `tosa.mul` operation now expects 3 operand groups instead of 2, as reflected in the error message: ```jsx error: invalid number of operand groups for `tosa.mul`; expected 3, but got 2 with op<tosa.mul>(input0, input1) {shift = attr<"0 : i8">}; ^ ``` This PR added a zero constant as the additional argument for the shift parameter to maintain compatibility with the updated `tosa.mul` operation.
1 parent b62dc66 commit bde336e

File tree

2 files changed

+10
-2
lines changed

2 files changed

+10
-2
lines changed

build_tools/llvm_version.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
aa65f93b71dee8cacb22be1957673c8be6a3ec24
1+
5c24847e7dba01dde230e18b39a3074022279c89

stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.pdll

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,14 @@
1515
#include "mlir/Dialect/Tosa/IR/TosaOps.td"
1616
#include "stablehlo/dialect/StablehloOps.td"
1717

18+
Rewrite zeroConst() -> Op [{
19+
auto type = rewriter.getI8Type();
20+
auto attr = mlir::DenseElementsAttr::get(
21+
llvm::cast<mlir::ShapedType>(type), rewriter.getZeroAttr(type));
22+
return rewriter.create<mlir::tosa::ConstOp>(
23+
rewriter.getUnknownLoc(), type, attr);
24+
}];
25+
1826
// Helper functions.
1927
Rewrite onesLike(op: Op, type: Type) -> Op [{
2028
auto elementType = llvm::cast<mlir::TensorType>(type).getElementType();
@@ -137,7 +145,7 @@ Pattern =>
137145
Pattern =>
138146
replace op<stablehlo.multiply>(input0 : Value<_: Tosa_Tensor>,
139147
input1 : Value<_: Tosa_Tensor>)
140-
with op<tosa.mul>(input0, input1) {shift = attr<"0 : i8">};
148+
with op<tosa.mul>(input0, input1, zeroConst());
141149
Pattern =>
142150
replace op<stablehlo.or>(input0 : Value<_: Tosa_Tensor>,
143151
input1 : Value<_: Tosa_Tensor>)

0 commit comments

Comments
 (0)