diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td index cb6c1b63e4e4b..adcf6fac752fe 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td @@ -586,10 +586,10 @@ def IsSparseTensorSlicePred " ::mlir::sparse_tensor::getSparseTensorEncoding($_self).isSlice()">; class SparseTensorOf allowedTypes> - : TensorOf; + : RankedTensorOf; class SparseTensorSliceOf allowedTypes> - : TensorOf; + : RankedTensorOf; class ScalarLikeOf allowedTypes> : AnyTypeOf<[0DTensorOf, AnyTypeOf], "scalar like">; diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td index 96a61419a541f..2c281c9f6aa85 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td @@ -92,8 +92,8 @@ def SparseTensor_AssembleOp : SparseTensor_Op<"assemble", [Pure]> { ``` }]; - let arguments = (ins Variadic>:$levels, - TensorOf<[AnyType]>:$values); + let arguments = (ins Variadic>:$levels, + RankedTensorOf<[AnyType]>:$values); let results = (outs AnySparseTensor: $result); let assemblyFormat = "` ` `(` $levels `)` `,` $values attr-dict `:`" @@ -138,12 +138,12 @@ def SparseTensor_DisassembleOp : SparseTensor_Op<"disassemble", [Pure, SameVaria }]; let arguments = (ins AnySparseTensor:$tensor, - Variadic>:$out_levels, - TensorOf<[AnyType]>:$out_values); - let results = (outs Variadic>:$ret_levels, - TensorOf<[AnyType]>:$ret_values, - Variadic:$lvl_lens, - AnyIndexingScalarLike:$val_len); + Variadic>:$out_levels, + RankedTensorOf<[AnyType]>:$out_values); + let results = (outs Variadic>:$ret_levels, + RankedTensorOf<[AnyType]>:$ret_values, + Variadic:$lvl_lens, + AnyIndexingScalarLike:$val_len); let assemblyFormat = "$tensor attr-dict `:` type($tensor)" "`out_lvls` `(` $out_levels `:` type($out_levels) `)` " @@ -196,8 +196,8 @@ def SparseTensor_ConvertOp : SparseTensor_Op<"convert", }]; - let arguments = (ins AnyTensor:$source); - let results = (outs AnyTensor:$dest); + let arguments = (ins AnyRankedTensor:$source); + let results = (outs AnyRankedTensor:$dest); let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)"; let extraClassDeclaration = [{ @@ -1447,7 +1447,7 @@ def SparseTensor_ForeachOp : SparseTensor_Op<"foreach", ]; let regions = (region SizedRegion<1>:$region); - let arguments = (ins AnyTensor:$tensor, + let arguments = (ins AnyRankedTensor:$tensor, Variadic:$initArgs, OptionalAttr:$order); let results = (outs Variadic:$results); diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp index b21bc1a93036c..7b1b1f383e634 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -1310,7 +1310,7 @@ static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape, // The coordinates should be in shape of unsigned expCOORank = stt.getLvlRank() - cooStartLvl; if (cooTp.getRank() != 2 || expCOORank != cooTp.getShape().back()) { - op->emitError("input/output trailing COO level-ranks don't match"); + return op->emitError("input/output trailing COO level-ranks don't match"); } } @@ -1350,7 +1350,7 @@ static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape, } LogicalResult AssembleOp::verify() { - const auto valuesTp = getRankedTensorType(getValues()); + RankedTensorType valuesTp = getValues().getType(); const auto lvlsTp = getLevels().getTypes(); const auto resTp = getSparseTensorType(getResult()); return verifyPackUnPack(*this, true, resTp, valuesTp, lvlsTp); @@ -1364,34 +1364,31 @@ LogicalResult DisassembleOp::verify() { if (ot.getType() != rt.getType()) return emitError("output levels and return levels type mismatch"); - const auto valuesTp = getRankedTensorType(getRetValues()); + RankedTensorType valuesTp = getRetValues().getType(); const auto lvlsTp = getRetLevels().getTypes(); const auto srcTp = getSparseTensorType(getTensor()); return verifyPackUnPack(*this, false, srcTp, valuesTp, lvlsTp); } LogicalResult ConvertOp::verify() { - if (auto tp1 = llvm::dyn_cast(getSource().getType())) { - if (auto tp2 = llvm::dyn_cast(getDest().getType())) { - if (tp1.getRank() != tp2.getRank()) - return emitError("unexpected conversion mismatch in rank"); - auto dstEnc = - llvm::dyn_cast_or_null(tp2.getEncoding()); - if (dstEnc && dstEnc.isSlice()) - return emitError("cannot convert to a sparse tensor slice"); - - auto shape1 = tp1.getShape(); - auto shape2 = tp2.getShape(); - // Accept size matches between the source and the destination type - // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or - // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10). - for (Dimension d = 0, dimRank = tp1.getRank(); d < dimRank; d++) - if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamic) - return emitError("unexpected conversion mismatch in dimension ") << d; - return success(); - } - } - return emitError("unexpected type in convert"); + RankedTensorType tp1 = getSource().getType(); + RankedTensorType tp2 = getDest().getType(); + if (tp1.getRank() != tp2.getRank()) + return emitError("unexpected conversion mismatch in rank"); + auto dstEnc = + llvm::dyn_cast_or_null(tp2.getEncoding()); + if (dstEnc && dstEnc.isSlice()) + return emitError("cannot convert to a sparse tensor slice"); + + auto shape1 = tp1.getShape(); + auto shape2 = tp2.getShape(); + // Accept size matches between the source and the destination type + // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or + // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10). + for (Dimension d = 0, dimRank = tp1.getRank(); d < dimRank; d++) + if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamic) + return emitError("unexpected conversion mismatch in dimension ") << d; + return success(); } OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) { @@ -1495,7 +1492,8 @@ LogicalResult LvlOp::verify() { if (std::optional lvl = getConstantLvlIndex()) { auto stt = getSparseTensorType(getSource()); if (static_cast(lvl.value()) >= stt.getLvlRank()) - emitError("Level index exceeds the rank of the input sparse tensor"); + return emitError( + "Level index exceeds the rank of the input sparse tensor"); } return success(); } @@ -1697,14 +1695,14 @@ LogicalResult ToValuesOp::inferReturnTypes(MLIRContext *ctx, } LogicalResult ToSliceOffsetOp::verify() { - auto rank = getRankedTensorType(getSlice()).getRank(); + auto rank = getSlice().getType().getRank(); if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0) return emitError("requested dimension out of bound"); return success(); } LogicalResult ToSliceStrideOp::verify() { - auto rank = getRankedTensorType(getSlice()).getRank(); + auto rank = getSlice().getType().getRank(); if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0) return emitError("requested dimension out of bound"); return success(); @@ -1986,15 +1984,16 @@ LogicalResult ForeachOp::verify() { const auto iTp = IndexType::get(getContext()); for (Dimension d = 0; d < dimRank; d++) if (args[d].getType() != iTp) - emitError( + return emitError( llvm::formatv("Expecting Index type for argument at index {0}", d)); const auto elemTp = t.getElementType(); const auto valueTp = args[dimRank].getType(); if (elemTp != valueTp) - emitError(llvm::formatv("Unmatched element type between input tensor and " - "block argument, expected:{0}, got: {1}", - elemTp, valueTp)); + return emitError( + llvm::formatv("Unmatched element type between input tensor and " + "block argument, expected:{0}, got: {1}", + elemTp, valueTp)); return success(); } @@ -2011,15 +2010,15 @@ LogicalResult ReorderCOOOp::verify() { SparseTensorType dstStt = getSparseTensorType(getResultCoo()); if (!srcStt.isCOOType() || !dstStt.isCOOType()) - emitError("Expected COO sparse tensors only"); + return emitError("Expected COO sparse tensors only"); if (!srcStt.hasSameDimToLvl(dstStt)) - emitError("Unmatched dim2lvl map between input and result COO"); + return emitError("Unmatched dim2lvl map between input and result COO"); if (srcStt.getPosType() != dstStt.getPosType() || srcStt.getCrdType() != dstStt.getCrdType() || srcStt.getElementType() != dstStt.getElementType()) - emitError("Unmatched storage format between input and result COO"); + return emitError("Unmatched storage format between input and result COO"); return success(); } @@ -2044,10 +2043,11 @@ LogicalResult SortOp::verify() { AffineMap xPerm = getPermMap(); uint64_t nx = xPerm.getNumDims(); if (nx < 1) - emitError(llvm::formatv("Expected rank(perm_map) > 1, got {0}", nx)); + return emitError(llvm::formatv("Expected rank(perm_map) > 1, got {0}", nx)); if (!xPerm.isPermutation()) - emitError(llvm::formatv("Expected a permutation map, got {0}", xPerm)); + return emitError( + llvm::formatv("Expected a permutation map, got {0}", xPerm)); // We can't check the size of the buffers when n or buffer dimensions aren't // compile-time constants. @@ -2056,19 +2056,24 @@ LogicalResult SortOp::verify() { return success(); // Verify dimensions. - const auto checkDim = [&](Value v, Size minSize, const char *message) { + const auto checkDim = [&](Value v, Size minSize, + const char *message) -> LogicalResult { const Size sh = getMemRefType(v).getShape()[0]; if (!ShapedType::isDynamic(sh) && sh < minSize) - emitError(llvm::formatv("{0} got {1} < {2}", message, sh, minSize)); + return emitError( + llvm::formatv("{0} got {1} < {2}", message, sh, minSize)); + return success(); }; uint64_t n = cn.value(); uint64_t ny = 0; if (auto nyAttr = getNyAttr()) ny = nyAttr.getInt(); - checkDim(getXy(), n * (nx + ny), - "Expected dimension(xy) >= n * (rank(perm_map) + ny)"); + if (failed(checkDim(getXy(), n * (nx + ny), + "Expected dimension(xy) >= n * (rank(perm_map) + ny)"))) + return failure(); for (Value opnd : getYs()) - checkDim(opnd, n, "Expected dimension(y) >= n"); + if (failed(checkDim(opnd, n, "Expected dimension(y) >= n"))) + return failure(); return success(); } @@ -2101,8 +2106,8 @@ static ParseResult parseLevelRange(AsmParser &parser, Level &lvlLo, } if (lvlHi <= lvlLo) - parser.emitError(parser.getNameLoc(), - "expect larger level upper bound than lower bound"); + return parser.emitError(parser.getNameLoc(), + "expect larger level upper bound than lower bound"); return success(); } diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir index 737b736ba795f..908d2d8aa83f7 100644 --- a/mlir/test/Dialect/SparseTensor/invalid.mlir +++ b/mlir/test/Dialect/SparseTensor/invalid.mlir @@ -105,7 +105,7 @@ func.func @invalid_positions_dense(%arg0: tensor<128xf64>) -> memref { func.func @invalid_positions_unranked(%arg0: tensor<*xf64>) -> memref { // expected-error@+1 {{'sparse_tensor.positions' op operand #0 must be sparse tensor of any type values, but got 'tensor<*xf64>'}} - %0 = sparse_tensor.positions %arg0 { level = 0 : index } : tensor<*xf64> to memref + %0 = "sparse_tensor.positions"(%arg0) { level = 0 : index } : (tensor<*xf64>) -> (memref) return %0 : memref } @@ -141,7 +141,7 @@ func.func @invalid_indices_dense(%arg0: tensor<10x10xi32>) -> memref { func.func @invalid_indices_unranked(%arg0: tensor<*xf64>) -> memref { // expected-error@+1 {{'sparse_tensor.coordinates' op operand #0 must be sparse tensor of any type values, but got 'tensor<*xf64>'}} - %0 = sparse_tensor.coordinates %arg0 { level = 0 : index } : tensor<*xf64> to memref + %0 = "sparse_tensor.coordinates"(%arg0) { level = 0 : index } : (tensor<*xf64>) -> (memref) return %0 : memref } @@ -347,7 +347,7 @@ func.func @sparse_wrong_arity_compression(%arg0: memref, // ----- func.func @sparse_convert_unranked(%arg0: tensor<*xf32>) -> tensor<10xf32> { - // expected-error@+1 {{unexpected type in convert}} + // expected-error@+1 {{invalid kind of type specified}} %0 = sparse_tensor.convert %arg0 : tensor<*xf32> to tensor<10xf32> return %0 : tensor<10xf32> }