@@ -2648,39 +2648,56 @@ SmallVector<utils::IteratorType> SoftmaxOp::getLoopIteratorTypes() {
26482648 return iteratorTypes;
26492649}
26502650
2651- FailureOr<TilingResult>
2652- SoftmaxOp::getTiledImplementation ( OpBuilder &builder,
2653- ArrayRef<OpFoldResult> offsets,
2654- ArrayRef<OpFoldResult> sizes) {
2655- int64_t rank = getInputOperandRank ();
2651+ static FailureOr<TilingResult>
2652+ implementTiledSoftMax (SoftmaxOp &op, OpBuilder &builder,
2653+ ArrayRef<OpFoldResult> offsets,
2654+ ArrayRef<OpFoldResult> sizes) {
2655+ int64_t rank = op. getInputOperandRank ();
26562656 auto oneAttr = builder.getI64IntegerAttr (1 );
26572657 SmallVector<OpFoldResult> strides (rank, oneAttr);
26582658 SmallVector<Value> tiledOperands;
26592659 Operation *inputSlice =
2660- getSlice (builder, getLoc (), getInput (), offsets, sizes, strides);
2660+ getSlice (builder, op. getLoc (), op. getInput (), offsets, sizes, strides);
26612661 if (!inputSlice) {
2662- return emitOpError (" failed to compute input slice" );
2662+ return op. emitOpError (" failed to compute input slice" );
26632663 }
26642664 tiledOperands.emplace_back (inputSlice->getResult (0 ));
26652665 Operation *outputSlice =
2666- getSlice (builder, getLoc (), getOutput (), offsets, sizes, strides);
2666+ getSlice (builder, op. getLoc (), op. getOutput (), offsets, sizes, strides);
26672667 if (!outputSlice) {
2668- return emitOpError (" failed to compute output slice" );
2668+ return op. emitOpError (" failed to compute output slice" );
26692669 }
26702670 tiledOperands.emplace_back (outputSlice->getResult (0 ));
26712671
26722672 SmallVector<Type, 4 > resultTypes;
2673- if (hasPureTensorSemantics ())
2673+ if (op. hasPureTensorSemantics ())
26742674 resultTypes.push_back (tiledOperands[1 ].getType ());
26752675 Operation *tiledOp =
2676- mlir::clone (builder, getOperation (), resultTypes, tiledOperands);
2676+ mlir::clone (builder, op. getOperation (), resultTypes, tiledOperands);
26772677
26782678 return TilingResult{
26792679 {tiledOp},
26802680 SmallVector<Value>(tiledOp->getResults ()),
26812681 llvm::to_vector (ArrayRef<Operation *>{inputSlice, outputSlice})};
26822682}
26832683
2684+ FailureOr<TilingResult>
2685+ SoftmaxOp::getTiledImplementation (OpBuilder &builder,
2686+ ArrayRef<OpFoldResult> offsets,
2687+ ArrayRef<OpFoldResult> sizes) {
2688+ return implementTiledSoftMax (*this , builder, offsets, sizes);
2689+ }
2690+
2691+ FailureOr<TilingResult>
2692+ SoftmaxOp::generateResultTileValue (OpBuilder &builder, unsigned resultNumber,
2693+ ArrayRef<OpFoldResult> offsets,
2694+ ArrayRef<OpFoldResult> sizes) {
2695+ if (resultNumber != 0 )
2696+ return failure ();
2697+
2698+ return implementTiledSoftMax (*this , builder, offsets, sizes);
2699+ }
2700+
26842701LogicalResult SoftmaxOp::getResultTilePosition (
26852702 OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
26862703 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
0 commit comments