@@ -1051,6 +1051,84 @@ LogicalResult OpWithRefineTypeInterfaceOp::refineReturnTypes(
10511051 return success ();
10521052}
10531053
1054+ // ===----------------------------------------------------------------------===//
1055+ // TilingNoDpsOp
1056+ // ===----------------------------------------------------------------------===//
1057+
1058+ static Value getSlice (OpBuilder &builder, Location loc, Value source,
1059+ ArrayRef<OpFoldResult> offsets,
1060+ ArrayRef<OpFoldResult> sizes,
1061+ ArrayRef<OpFoldResult> strides) {
1062+ auto staticOffsets = getConstantIntValues (offsets);
1063+ auto staticSizes = getConstantIntValues (sizes);
1064+ auto staticStrides = getConstantIntValues (strides);
1065+
1066+ auto sourceShape = cast<ShapedType>(source.getType ()).getShape ();
1067+ if (staticSizes && ArrayRef (*staticSizes) == sourceShape)
1068+ return source;
1069+
1070+ return {mlir::tensor::ExtractSliceOp::create (builder, loc, source, offsets,
1071+ sizes, strides)};
1072+ }
1073+
1074+ static ShapedType getSliceType (ShapedType type, ArrayRef<OpFoldResult> sizes) {
1075+ auto staticSizes = getConstantIntValues (sizes);
1076+ if (staticSizes.has_value ())
1077+ return type.cloneWith (*staticSizes, type.getElementType ());
1078+ return nullptr ;
1079+ }
1080+
1081+ SmallVector<Range> TilingNoDpsOp::getIterationDomain (OpBuilder &builder) {
1082+ auto shape = cast<ShapedType>(getResult ().getType ()).getShape ();
1083+ auto zero = getAsIndexOpFoldResult (getContext (), 0 );
1084+ auto one = getAsIndexOpFoldResult (getContext (), 1 );
1085+ return llvm::map_to_vector (shape, [&](int64_t size) {
1086+ return Range{.offset = zero,
1087+ .size = getAsIndexOpFoldResult (getContext (), size),
1088+ .stride = one};
1089+ });
1090+ }
1091+
1092+ SmallVector<utils::IteratorType> TilingNoDpsOp::getLoopIteratorTypes () {
1093+ auto tensorType = cast<ShapedType>(getResult ().getType ());
1094+ SmallVector<utils::IteratorType> types (
1095+ static_cast <size_t >(tensorType.getRank ()), utils::IteratorType::parallel);
1096+ return types;
1097+ }
1098+
1099+ FailureOr<TilingResult>
1100+ TilingNoDpsOp::getTiledImplementation (OpBuilder &builder,
1101+ ArrayRef<OpFoldResult> offsets,
1102+ ArrayRef<OpFoldResult> sizes) {
1103+ auto loc = getLoc ();
1104+ auto strides = SmallVector<OpFoldResult>(
1105+ static_cast <size_t >(cast<ShapedType>(getOperand (0 ).getType ()).getRank ()),
1106+ getAsIndexOpFoldResult (getContext (), 1 ));
1107+ auto inputSlices = llvm::map_to_vector (getOperands (), [&](Value operand) {
1108+ return getSlice (builder, loc, operand, offsets, sizes, strides);
1109+ });
1110+ auto resultType =
1111+ getSliceType (cast<ShapedType>(getResult ().getType ()), sizes);
1112+ auto tiledOp = TilingNoDpsOp::create (builder, loc, TypeRange{resultType},
1113+ ValueRange (inputSlices));
1114+ return TilingResult{.tiledOps = {tiledOp},
1115+ .tiledValues = SmallVector<Value>{tiledOp.getResult ()},
1116+ .generatedSlices =
1117+ map_to_vector (inputSlices, [](Value val) {
1118+ return val.getDefiningOp ();
1119+ })};
1120+ }
1121+
1122+ LogicalResult TilingNoDpsOp::getResultTilePosition (
1123+ OpBuilder & builder, unsigned resultNumber,
1124+ ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
1125+ SmallVector<OpFoldResult> &resultOffsets,
1126+ SmallVector<OpFoldResult> &resultSizes) {
1127+ resultOffsets.assign (offsets.begin (), offsets.end ());
1128+ resultSizes.assign (sizes.begin (), sizes.end ());
1129+ return success ();
1130+ }
1131+
10541132// ===----------------------------------------------------------------------===//
10551133// OpWithShapedTypeInferTypeAdaptorInterfaceOp
10561134// ===----------------------------------------------------------------------===//
0 commit comments