Skip to content

Commit d266b33

Browse files
committed
[AutoBump] Merge with fixes of 705f858 (Jun 13)
2 parents 1f6dc6e + 705f858 commit d266b33

File tree

3 files changed

+258
-12
lines changed

3 files changed

+258
-12
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,9 @@ class MulOperandsAndResultElementType
7878
return success();
7979
}
8080

81-
return failure();
81+
// In cases of all other types, op requires the same element
82+
// type for all operands and result.
83+
return impl::verifySameOperandsAndResultElementType(op);
8284
}
8385
};
8486

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -753,6 +753,11 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
753753
if (!lhsTy || !rhsTy || !resultTy)
754754
return {};
755755

756+
// Cannot create an ElementsAttr from non-int/float/index types
757+
if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
758+
!rhsTy.getElementType().isIntOrIndexOrFloat())
759+
return {};
760+
756761
auto resultETy = resultTy.getElementType();
757762
auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
758763
auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
@@ -791,6 +796,7 @@ OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
791796
if (lhsTy != rhsTy)
792797
return {};
793798

799+
// IntDivOp inputs must be integer type, no need to check for quantized type
794800
auto resultETy = resultTy.getElementType();
795801
auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
796802
auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
@@ -888,6 +894,11 @@ OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
888894
if (!lhsTy || !rhsTy || !resultTy)
889895
return {};
890896

897+
// Cannot create an ElementsAttr from non-int/float/index types
898+
if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
899+
!rhsTy.getElementType().isIntOrIndexOrFloat())
900+
return {};
901+
891902
auto resultETy = resultTy.getElementType();
892903
auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
893904
auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
@@ -1098,6 +1109,10 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
10981109
return getResult();
10991110
}
11001111

1112+
// Cannot create an ElementsAttr from non-int/float/index types
1113+
if (!inputTy.getElementType().isIntOrIndexOrFloat())
1114+
return {};
1115+
11011116
// reshape(const(x)) -> const(reshape-attr(x))
11021117
if (auto operand = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
11031118
// Constants must have static shape.
@@ -1233,13 +1248,12 @@ OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
12331248
}
12341249

12351250
OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
1236-
auto inputTy = llvm::cast<ShapedType>(getInput1().getType());
12371251
auto resultTy = llvm::cast<ShapedType>(getType());
12381252

12391253
// Transposing splat values just means reshaping.
12401254
if (auto input = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
12411255
if (input.isSplat() && resultTy.hasStaticShape() &&
1242-
inputTy.getElementType() == resultTy.getElementType())
1256+
input.getType().getElementType() == resultTy.getElementType())
12431257
return input.reshape(resultTy);
12441258
}
12451259

0 commit comments

Comments
 (0)