@@ -753,6 +753,11 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
753753 if (!lhsTy || !rhsTy || !resultTy)
754754 return {};
755755
756+ // Cannot create an ElementsAttr from non-int/float/index types
757+ if (!lhsTy.getElementType ().isIntOrIndexOrFloat () ||
758+ !rhsTy.getElementType ().isIntOrIndexOrFloat ())
759+ return {};
760+
756761 auto resultETy = resultTy.getElementType ();
757762 auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1 ());
758763 auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2 ());
@@ -791,6 +796,7 @@ OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
791796 if (lhsTy != rhsTy)
792797 return {};
793798
799+ // IntDivOp inputs must be integer type, no need to check for quantized type
794800 auto resultETy = resultTy.getElementType ();
795801 auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1 ());
796802 auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2 ());
@@ -888,6 +894,11 @@ OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
888894 if (!lhsTy || !rhsTy || !resultTy)
889895 return {};
890896
897+ // Cannot create an ElementsAttr from non-int/float/index types
898+ if (!lhsTy.getElementType ().isIntOrIndexOrFloat () ||
899+ !rhsTy.getElementType ().isIntOrIndexOrFloat ())
900+ return {};
901+
891902 auto resultETy = resultTy.getElementType ();
892903 auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1 ());
893904 auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2 ());
@@ -1098,6 +1109,10 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
10981109 return getResult ();
10991110 }
11001111
1112+ // Cannot create an ElementsAttr from non-int/float/index types
1113+ if (!inputTy.getElementType ().isIntOrIndexOrFloat ())
1114+ return {};
1115+
11011116 // reshape(const(x)) -> const(reshape-attr(x))
11021117 if (auto operand = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1 ())) {
11031118 // Constants must have static shape.
@@ -1233,13 +1248,12 @@ OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
12331248}
12341249
12351250OpFoldResult TransposeOp::fold (FoldAdaptor adaptor) {
1236- auto inputTy = llvm::cast<ShapedType>(getInput1 ().getType ());
12371251 auto resultTy = llvm::cast<ShapedType>(getType ());
12381252
12391253 // Transposing splat values just means reshaping.
12401254 if (auto input = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1 ())) {
12411255 if (input.isSplat () && resultTy.hasStaticShape () &&
1242- inputTy .getElementType () == resultTy.getElementType ())
1256+ input. getType () .getElementType () == resultTy.getElementType ())
12431257 return input.reshape (resultTy);
12441258 }
12451259
0 commit comments