|
15 | 15 | #include "mlir/Dialect/Tosa/IR/TosaOps.td" |
16 | 16 | #include "stablehlo/dialect/StablehloOps.td" |
17 | 17 |
|
18 | | -Rewrite zeroConst() -> Op [{ |
19 | | - auto type = rewriter.getI8Type(); |
20 | | - auto attr = mlir::DenseElementsAttr::get( |
21 | | - llvm::cast<mlir::ShapedType>(type), rewriter.getZeroAttr(type)); |
| 18 | +// Helper functions. |
| 19 | +Rewrite changeElementTypeToI1(type: Type) -> Type [{ |
| 20 | + auto tensorType = llvm::cast<mlir::RankedTensorType>(type); |
| 21 | + return RankedTensorType::get(tensorType.getShape(), rewriter.getI1Type()); |
| 22 | +}]; |
| 23 | + |
| 24 | +Rewrite changeElementTypeToI8(type: Type) -> Type [{ |
| 25 | + auto tensorType = llvm::cast<mlir::RankedTensorType>(type); |
| 26 | + return RankedTensorType::get(tensorType.getShape(), rewriter.getI8Type()); |
| 27 | +}]; |
| 28 | + |
| 29 | +Rewrite zerosLike(op: Op, type: Type) -> Op [{ |
| 30 | + auto elementType = llvm::cast<mlir::TensorType>(type).getElementType(); |
| 31 | + llvm::SmallVector<mlir::Attribute, 4> outputValue; |
| 32 | + |
| 33 | + if (elementType.isF16() || elementType.isF32() || elementType.isBF16()) { |
| 34 | + outputValue.push_back(rewriter.getFloatAttr(elementType, 0)); |
| 35 | + } else { |
| 36 | + outputValue.push_back(rewriter.getIntegerAttr(elementType, 0)); |
| 37 | + } |
| 38 | + |
22 | 39 | return rewriter.create<mlir::tosa::ConstOp>( |
23 | | - rewriter.getUnknownLoc(), type, attr); |
| 40 | + op->getLoc(), type, |
| 41 | + mlir::DenseElementsAttr::get( |
| 42 | + llvm::cast<mlir::ShapedType>(type), outputValue)); |
24 | 43 | }]; |
25 | 44 |
|
26 | | -// Helper functions. |
27 | 45 | Rewrite onesLike(op: Op, type: Type) -> Op [{ |
28 | 46 | auto elementType = llvm::cast<mlir::TensorType>(type).getElementType(); |
29 | 47 | llvm::SmallVector<mlir::Attribute, 4> outputValue; |
@@ -55,11 +73,6 @@ Rewrite positiveFloatInfinityLike(op: Op, type: Type) -> Op [{ |
55 | 73 | llvm::cast<mlir::ShapedType>(type), outputValue)); |
56 | 74 | }]; |
57 | 75 |
|
58 | | -Rewrite changeElementTypeToI1(type: Type) -> Type [{ |
59 | | - auto tensorType = llvm::cast<mlir::RankedTensorType>(type); |
60 | | - return RankedTensorType::get(tensorType.getShape(), rewriter.getI1Type()); |
61 | | -}]; |
62 | | - |
63 | 76 | // Nullary ops. |
64 | 77 | Pattern => |
65 | 78 | replace op<stablehlo.constant> {value = input: Attr<_: Tosa_Tensor>} |
@@ -142,10 +155,16 @@ Pattern => |
142 | 155 | replace op<stablehlo.minimum>(input0 : Value<_: Tosa_Tensor>, |
143 | 156 | input1 : Value<_: Tosa_Tensor>) |
144 | 157 | with op<tosa.minimum>(input0, input1); |
145 | | -Pattern => |
146 | | - replace op<stablehlo.multiply>(input0 : Value<_: Tosa_Tensor>, |
147 | | - input1 : Value<_: Tosa_Tensor>) |
148 | | - with op<tosa.mul>(input0, input1, zeroConst()); |
| 158 | +Pattern { |
| 159 | + let root = op<stablehlo.multiply>(input0 : Value<inputType: Tosa_Tensor>, |
| 160 | + input1 : Value<_: Tosa_Tensor>); |
| 161 | + rewrite root with { |
| 162 | + let typei8 = changeElementTypeToI8(inputType); |
| 163 | + let zeros = zerosLike(root, typei8); |
| 164 | + let mulResult = op<tosa.mul>(input0, input1, zeros) -> (inputType); |
| 165 | + replace root with mulResult; |
| 166 | + }; |
| 167 | +} |
149 | 168 | Pattern => |
150 | 169 | replace op<stablehlo.or>(input0 : Value<_: Tosa_Tensor>, |
151 | 170 | input1 : Value<_: Tosa_Tensor>) |
|
0 commit comments