Skip to content

Commit 7775e3e

Browse files
authored
1 parent bde336e commit 7775e3e

File tree

5 files changed

+43
-21
lines changed

5 files changed

+43
-21
lines changed

WORKSPACE.bazel

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@ workspace(name = "stablehlo")
1717

1818
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
1919

20-
LLVM_COMMIT = "aa65f93b71dee8cacb22be1957673c8be6a3ec24"
20+
LLVM_COMMIT = "956c0707d9098499a2682297b71f46b0a562eed9"
2121

22-
LLVM_SHA256 = "0a6046edb6a9834d5b912ec0e705dec91d39ee1b7b2fbb5930955d83d2090ff5"
22+
LLVM_SHA256 = "f90b866908daa3c65b74454943e52b59f40ab448f42a13b23e9823045f017066"
2323

2424
http_archive(
2525
name = "llvm-raw",

build_tools/llvm_version.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
5c24847e7dba01dde230e18b39a3074022279c89
1+
956c0707d9098499a2682297b71f46b0a562eed9

stablehlo/conversions/tosa/tests/unary.mlir

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,9 @@ func.func @negate(%arg : tensor<10xf32>) -> tensor<10xf32> {
7979

8080
// CHECK-LABEL: @slice
8181
func.func @slice(%arg : tensor<4x3xf32>) -> tensor<2x2xf32> {
82-
// CHECK: tosa.slice %arg0 {size = array<i64: 2, 2>, start = array<i64: 2, 1>}
82+
// CHECK: %[[SIZE:.*]] = tosa.const_shape {value = dense<[2, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
83+
// CHECK: %[[START:.*]] = tosa.const_shape {value = dense<2> : tensor<2xindex>} : () -> !tosa.shape<2>
84+
// CHECK: tosa.slice %arg0, %[[SIZE]], %[[START]]
8385
%0 = "stablehlo.slice"(%arg) {
8486
start_indices = array<i64: 2, 1>,
8587
limit_indices = array<i64: 4, 3>,

stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ limitations under the License.
2323
#include "mlir/Dialect/PDL/IR/PDL.h"
2424
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
2525
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
26+
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
2627
#include "mlir/IR/Attributes.h"
2728
#include "mlir/IR/Block.h"
2829
#include "mlir/IR/BuiltinAttributes.h"
@@ -435,8 +436,8 @@ struct ConvertStablehloSliceOp : public OpRewritePattern<stablehlo::SliceOp> {
435436

436437
rewriter.replaceOpWithNewOp<tosa::SliceOp>(
437438
op, op.getType(), op.getOperand(),
438-
rewriter.getDenseI64ArrayAttr(startIndicesI64),
439-
rewriter.getDenseI64ArrayAttr(size));
439+
getTosaConstShape(rewriter, op.getLoc(), startIndicesI64),
440+
getTosaConstShape(rewriter, op.getLoc(), size));
440441
return success();
441442
}
442443
};

stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.pdll

Lines changed: 34 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,33 @@
1515
#include "mlir/Dialect/Tosa/IR/TosaOps.td"
1616
#include "stablehlo/dialect/StablehloOps.td"
1717

18-
Rewrite zeroConst() -> Op [{
19-
auto type = rewriter.getI8Type();
20-
auto attr = mlir::DenseElementsAttr::get(
21-
llvm::cast<mlir::ShapedType>(type), rewriter.getZeroAttr(type));
18+
// Helper functions.
19+
Rewrite changeElementTypeToI1(type: Type) -> Type [{
20+
auto tensorType = llvm::cast<mlir::RankedTensorType>(type);
21+
return RankedTensorType::get(tensorType.getShape(), rewriter.getI1Type());
22+
}];
23+
24+
Rewrite changeElementTypeToI8(type: Type) -> Type [{
25+
auto tensorType = llvm::cast<mlir::RankedTensorType>(type);
26+
return RankedTensorType::get(tensorType.getShape(), rewriter.getI8Type());
27+
}];
28+
29+
Rewrite zerosLike(op: Op, type: Type) -> Op [{
30+
auto elementType = llvm::cast<mlir::TensorType>(type).getElementType();
31+
llvm::SmallVector<mlir::Attribute, 4> outputValue;
32+
33+
if (elementType.isF16() || elementType.isF32() || elementType.isBF16()) {
34+
outputValue.push_back(rewriter.getFloatAttr(elementType, 0));
35+
} else {
36+
outputValue.push_back(rewriter.getIntegerAttr(elementType, 0));
37+
}
38+
2239
return rewriter.create<mlir::tosa::ConstOp>(
23-
rewriter.getUnknownLoc(), type, attr);
40+
op->getLoc(), type,
41+
mlir::DenseElementsAttr::get(
42+
llvm::cast<mlir::ShapedType>(type), outputValue));
2443
}];
2544

26-
// Helper functions.
2745
Rewrite onesLike(op: Op, type: Type) -> Op [{
2846
auto elementType = llvm::cast<mlir::TensorType>(type).getElementType();
2947
llvm::SmallVector<mlir::Attribute, 4> outputValue;
@@ -55,11 +73,6 @@ Rewrite positiveFloatInfinityLike(op: Op, type: Type) -> Op [{
5573
llvm::cast<mlir::ShapedType>(type), outputValue));
5674
}];
5775

58-
Rewrite changeElementTypeToI1(type: Type) -> Type [{
59-
auto tensorType = llvm::cast<mlir::RankedTensorType>(type);
60-
return RankedTensorType::get(tensorType.getShape(), rewriter.getI1Type());
61-
}];
62-
6376
// Nullary ops.
6477
Pattern =>
6578
replace op<stablehlo.constant> {value = input: Attr<_: Tosa_Tensor>}
@@ -142,10 +155,16 @@ Pattern =>
142155
replace op<stablehlo.minimum>(input0 : Value<_: Tosa_Tensor>,
143156
input1 : Value<_: Tosa_Tensor>)
144157
with op<tosa.minimum>(input0, input1);
145-
Pattern =>
146-
replace op<stablehlo.multiply>(input0 : Value<_: Tosa_Tensor>,
147-
input1 : Value<_: Tosa_Tensor>)
148-
with op<tosa.mul>(input0, input1, zeroConst());
158+
Pattern {
159+
let root = op<stablehlo.multiply>(input0 : Value<inputType: Tosa_Tensor>,
160+
input1 : Value<_: Tosa_Tensor>);
161+
rewrite root with {
162+
let typei8 = changeElementTypeToI8(inputType);
163+
let zeros = zerosLike(root, typei8);
164+
let mulResult = op<tosa.mul>(input0, input1, zeros) -> (inputType);
165+
replace root with mulResult;
166+
};
167+
}
149168
Pattern =>
150169
replace op<stablehlo.or>(input0 : Value<_: Tosa_Tensor>,
151170
input1 : Value<_: Tosa_Tensor>)

0 commit comments

Comments
 (0)