@@ -67,20 +67,20 @@ static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v,
6767
6868// / Returns a memref.subview or a tensor.extract_slice based on the type of the
6969// / `source`.
70- static Value getSlice (OpBuilder &b, Location loc, Value source,
71- ArrayRef<OpFoldResult> offsets,
72- ArrayRef<OpFoldResult> sizes,
73- ArrayRef<OpFoldResult> strides) {
74- return TypeSwitch<Type, Value >(source.getType ())
75- .Case <RankedTensorType>([&](RankedTensorType t) -> Value {
70+ static Operation * getSlice (OpBuilder &b, Location loc, Value source,
71+ ArrayRef<OpFoldResult> offsets,
72+ ArrayRef<OpFoldResult> sizes,
73+ ArrayRef<OpFoldResult> strides) {
74+ return TypeSwitch<Type, Operation * >(source.getType ())
75+ .Case <RankedTensorType>([&](RankedTensorType t) -> Operation * {
7676 return b.create <tensor::ExtractSliceOp>(loc, source, offsets, sizes,
7777 strides);
7878 })
79- .Case <MemRefType>([&](MemRefType type) -> Value {
79+ .Case <MemRefType>([&](MemRefType type) -> Operation * {
8080 return b.create <memref::SubViewOp>(loc, source, offsets, sizes,
8181 strides);
8282 })
83- .Default ([&](Type t) { return nullptr ; });
83+ .Default ([&](Type t) -> Operation * { return nullptr ; });
8484}
8585
8686// ===----------------------------------------------------------------------===//
@@ -2634,18 +2634,29 @@ SoftmaxOp::getTiledImplementation(OpBuilder &builder,
26342634 auto oneAttr = builder.getI64IntegerAttr (1 );
26352635 SmallVector<OpFoldResult> strides (rank, oneAttr);
26362636 SmallVector<Value> tiledOperands;
2637- tiledOperands.emplace_back (
2638- getSlice (builder, getLoc (), getInput (), offsets, sizes, strides));
2639- tiledOperands.emplace_back (
2640- getSlice (builder, getLoc (), getOutput (), offsets, sizes, strides));
2637+ Operation *inputSlice =
2638+ getSlice (builder, getLoc (), getInput (), offsets, sizes, strides);
2639+ if (!inputSlice) {
2640+ return emitOpError (" failed to compute input slice" );
2641+ }
2642+ tiledOperands.emplace_back (inputSlice->getResult (0 ));
2643+ Operation *outputSlice =
2644+ getSlice (builder, getLoc (), getOutput (), offsets, sizes, strides);
2645+ if (!outputSlice) {
2646+ return emitOpError (" failed to compute output slice" );
2647+ }
2648+ tiledOperands.emplace_back (outputSlice->getResult (0 ));
26412649
26422650 SmallVector<Type, 4 > resultTypes;
26432651 if (hasPureTensorSemantics ())
26442652 resultTypes.push_back (tiledOperands[1 ].getType ());
26452653 Operation *tiledOp =
26462654 mlir::clone (builder, getOperation (), resultTypes, tiledOperands);
26472655
2648- return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults ())};
2656+ return TilingResult{
2657+ {tiledOp},
2658+ SmallVector<Value>(tiledOp->getResults ()),
2659+ llvm::to_vector (ArrayRef<Operation *>{inputSlice, outputSlice})};
26492660}
26502661
26512662LogicalResult SoftmaxOp::getResultTilePosition (
@@ -2992,8 +3003,9 @@ FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
29923003 int64_t filterRank = getFilterOperandRank ();
29933004 SmallVector<OpFoldResult> filterStrides (filterRank, oneAttr);
29943005 Location loc = getLoc ();
2995- tiledOperands.emplace_back (builder.create <tensor::ExtractSliceOp>(
2996- loc, getFilter (), sliceOffsets, sliceSizes, filterStrides));
3006+ auto filterSlice = builder.create <tensor::ExtractSliceOp>(
3007+ loc, getFilter (), sliceOffsets, sliceSizes, filterStrides);
3008+ tiledOperands.emplace_back (filterSlice);
29973009
29983010 SmallVector<OpFoldResult> resultOffsets, resultSizes;
29993011 if (failed (getResultTilePosition (builder, 1 , offsets, sizes, resultOffsets,
@@ -3002,15 +3014,19 @@ FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
30023014
30033015 int64_t outputRank = getOutputOperandRank ();
30043016 SmallVector<OpFoldResult> outputStrides (outputRank, oneAttr);
3005- tiledOperands.emplace_back (builder.create <tensor::ExtractSliceOp>(
3006- loc, getOutput (), resultOffsets, resultSizes, outputStrides));
3017+ auto outputSlice = builder.create <tensor::ExtractSliceOp>(
3018+ loc, getOutput (), resultOffsets, resultSizes, outputStrides);
3019+ tiledOperands.emplace_back (outputSlice);
30073020
30083021 SmallVector<Type> resultTypes;
30093022 resultTypes.push_back (tiledOperands[1 ].getType ());
30103023 Operation *tiledOp =
30113024 mlir::clone (builder, getOperation (), resultTypes, tiledOperands);
30123025
3013- return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults ())};
3026+ return TilingResult{
3027+ {tiledOp},
3028+ SmallVector<Value>(tiledOp->getResults ()),
3029+ llvm::to_vector (ArrayRef<Operation *>{filterSlice, outputSlice})};
30143030}
30153031
30163032// ===----------------------------------------------------------------------===//
@@ -3159,8 +3175,9 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
31593175 {sizes[getOutputNDim ()], sizeH, sizeW, sizes[getOutputCDim ()]});
31603176 int64_t inputRank = getInputOperandRank ();
31613177 SmallVector<OpFoldResult> inputStrides (inputRank, oneAttr);
3162- tiledOperands.emplace_back (builder.create <tensor::ExtractSliceOp>(
3163- loc, getInput (), sliceOffsets, sliceSizes, inputStrides));
3178+ auto inputSlice = builder.create <tensor::ExtractSliceOp>(
3179+ loc, getInput (), sliceOffsets, sliceSizes, inputStrides);
3180+ tiledOperands.emplace_back (inputSlice);
31643181
31653182 SmallVector<OpFoldResult> resultOffsets, resultSizes;
31663183 if (failed (getResultTilePosition (builder, 1 , offsets, sizes, resultOffsets,
@@ -3169,15 +3186,19 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
31693186
31703187 int64_t outputRank = getOutputOperandRank ();
31713188 SmallVector<OpFoldResult> outputStrides (outputRank, oneAttr);
3172- tiledOperands.emplace_back (builder.create <tensor::ExtractSliceOp>(
3173- loc, getOutput (), resultOffsets, resultSizes, outputStrides));
3189+ auto outputSlice = builder.create <tensor::ExtractSliceOp>(
3190+ loc, getOutput (), resultOffsets, resultSizes, outputStrides);
3191+ tiledOperands.emplace_back (outputSlice);
31743192
31753193 SmallVector<Type> resultTypes;
31763194 resultTypes.push_back (tiledOperands[1 ].getType ());
31773195 Operation *tiledOp =
31783196 mlir::clone (builder, getOperation (), resultTypes, tiledOperands);
31793197
3180- return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults ())};
3198+ return TilingResult{
3199+ {tiledOp},
3200+ SmallVector<Value>(tiledOp->getResults ()),
3201+ llvm::to_vector (ArrayRef<Operation *>{inputSlice, outputSlice})};
31813202}
31823203
31833204// ===----------------------------------------------------------------------===//
@@ -3321,8 +3342,9 @@ FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
33213342 sizes[getValueFDim ()]});
33223343 int64_t valueRank = getValueOperandRank ();
33233344 SmallVector<OpFoldResult> sliceStrides (valueRank, oneAttr);
3324- tiledOperands.emplace_back (builder.create <tensor::ExtractSliceOp>(
3325- loc, getValue (), sliceOffsets, sliceSizes, sliceStrides));
3345+ auto valueSlice = builder.create <tensor::ExtractSliceOp>(
3346+ loc, getValue (), sliceOffsets, sliceSizes, sliceStrides);
3347+ tiledOperands.emplace_back (valueSlice);
33263348
33273349 SmallVector<OpFoldResult> resultOffsets, resultSizes;
33283350 if (failed (getResultTilePosition (builder, 1 , offsets, sizes, resultOffsets,
@@ -3331,15 +3353,19 @@ FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
33313353
33323354 int64_t outputRank = getOutputOperandRank ();
33333355 SmallVector<OpFoldResult> strides (outputRank, oneAttr);
3334- tiledOperands.emplace_back (builder.create <tensor::ExtractSliceOp>(
3335- loc, getOutput (), resultOffsets, resultSizes, strides));
3356+ auto outputSlice = builder.create <tensor::ExtractSliceOp>(
3357+ loc, getOutput (), resultOffsets, resultSizes, strides);
3358+ tiledOperands.emplace_back (outputSlice);
33363359
33373360 SmallVector<Type> resultTypes;
33383361 resultTypes.push_back (tiledOperands[1 ].getType ());
33393362 Operation *tiledOp =
33403363 mlir::clone (builder, getOperation (), resultTypes, tiledOperands);
33413364
3342- return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults ())};
3365+ return TilingResult{
3366+ {tiledOp},
3367+ SmallVector<Value>(tiledOp->getResults ()),
3368+ llvm::to_vector (ArrayRef<Operation *>{valueSlice, outputSlice})};
33433369}
33443370
33453371// ===----------------------------------------------------------------------===//
0 commit comments