diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index 4c2a7c36d8b5e..6700b4e0c2cb6 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -33,6 +33,7 @@ #include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/InterleavedRange.h" #include "llvm/Support/raw_ostream.h" #include #include @@ -95,6 +96,10 @@ static bool hasAtMostOneResultFunctionOfDim(AffineMap map, int64_t dim) { } return true; } + +static std::string stringifyReassocIndices(ReassociationIndicesRef ri) { + return llvm::interleaved(ri, ", ", /*Prefix=*/"|", /*Suffix=*/""); +} #endif // NDEBUG /// Return the index of the first result of `map` that is a function of @@ -278,22 +283,21 @@ FailureOr linalg::lowerPack(RewriterBase &rewriter, highs, paddingValue, /*nofold=*/false); LLVM_DEBUG( - DBGSNL(); DBGSNL(); llvm::interleaveComma(packingMetadata.insertPositions, - DBGS() << "insertPositions: "); - DBGSNL(); llvm::interleaveComma(packingMetadata.outerPositions, - DBGS() << "outerPositions: "); - DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(), - DBGS() << "packedShape: "); + DBGSNL(); DBGSNL(); + DBGS() << "insertPositions: " + << llvm::interleaved(packingMetadata.insertPositions); + DBGSNL(); DBGS() << "outerPositions: " + << llvm::interleaved(packingMetadata.outerPositions); + DBGSNL(); DBGS() << "packedShape: " + << llvm::interleaved(packedTensorType.getShape()); + DBGSNL(); DBGS() << "packedToStripMinedShapePerm: " + << llvm::interleaved(packedToStripMinedShapePerm); DBGSNL(); - llvm::interleaveComma(packedToStripMinedShapePerm, - DBGS() << "packedToStripMinedShapePerm: "); - DBGSNL(); llvm::interleaveComma( - packingMetadata.reassociations, DBGS() << "reassociations: ", - [&](ReassociationIndices ri) { - llvm::interleaveComma(ri, llvm::dbgs() << "|"); - }); + DBGS() << "reassociations: " + << llvm::interleaved(llvm::map_range( + packingMetadata.reassociations, stringifyReassocIndices)); DBGSNL(); - llvm::interleaveComma(stripMinedShape, DBGS() << "stripMinedShape: "); + DBGS() << "stripMinedShape: " << llvm::interleaved(stripMinedShape); DBGSNL(); DBGS() << "collapsed type: " << collapsed; DBGSNL();); if (lowerPadLikeWithInsertSlice && packOp.isLikePad()) { @@ -343,7 +347,7 @@ FailureOr linalg::lowerPack(RewriterBase &rewriter, LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL(); DBGS() << "reshape op: " << reshapeOp; DBGSNL(); - llvm::interleaveComma(transpPerm, DBGS() << "transpPerm: "); + DBGS() << "transpPerm: " << llvm::interleaved(transpPerm); DBGSNL(); DBGS() << "transpose op: " << transposeOp; DBGSNL();); // 7. Replace packOp by transposeOp. @@ -412,20 +416,19 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp, loc, unPackOp.getSource(), emptyOp, packedToStripMinedShapePerm); LLVM_DEBUG( - DBGSNL(); DBGSNL(); llvm::interleaveComma(packingMetadata.insertPositions, - DBGS() << "insertPositions: "); - DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(), - DBGS() << "packedShape: "); + DBGSNL(); DBGSNL(); + DBGS() << "insertPositions: " + << llvm::interleaved(packingMetadata.insertPositions); + DBGSNL(); DBGS() << "packedShape: " + << llvm::interleaved(packedTensorType.getShape()); + DBGSNL(); DBGS() << "packedToStripMinedShapePerm: " + << llvm::interleaved(packedToStripMinedShapePerm); DBGSNL(); - llvm::interleaveComma(packedToStripMinedShapePerm, - DBGS() << "packedToStripMinedShapePerm: "); - DBGSNL(); llvm::interleaveComma( - packingMetadata.reassociations, DBGS() << "reassociations: ", - [&](ReassociationIndices ri) { - llvm::interleaveComma(ri, llvm::dbgs() << "|"); - }); + DBGS() << "reassociations: " + << llvm::interleaved(llvm::map_range( + packingMetadata.reassociations, stringifyReassocIndices)); DBGSNL(); - llvm::interleaveComma(stripMinedShape, DBGS() << "stripMinedShape: "); + DBGS() << "stripMinedShape: " << llvm::interleaved(stripMinedShape); DBGSNL(); DBGS() << "collapsed type: " << collapsedType; DBGSNL();); // 4. Collapse from the stripMinedShape to the padded result. @@ -488,10 +491,10 @@ FailureOr linalg::pack(RewriterBase &rewriter, SmallVector indexingMaps = linalgOp.getIndexingMapsArray(); SmallVector iteratorTypes = linalgOp.getIteratorTypesArray(); - LLVM_DEBUG(DBGS() << "Start packing: " << linalgOp << "\n"; - llvm::interleaveComma(indexingMaps, DBGS() << "maps: "); DBGSNL(); - llvm::interleaveComma(iteratorTypes, DBGS() << "iterators: "); - DBGSNL();); + LLVM_DEBUG(DBGS() << "Start packing: " << linalgOp << "\n" + << "maps: " << llvm::interleaved(indexingMaps) << "\n" + << "iterators: " << llvm::interleaved(iteratorTypes) + << "\n"); SmallVector packOps; SmallVector unPackOps; @@ -515,18 +518,18 @@ FailureOr linalg::pack(RewriterBase &rewriter, LLVM_DEBUG( DBGS() << "++++ After pack size #" << i << ": " << packedSizes[i] - << "\n"; - llvm::interleaveComma(indexingMaps, DBGS() << "maps: "); DBGSNL(); - llvm::interleaveComma(iteratorTypes, DBGS() << "iterators: "); DBGSNL(); - llvm::interleaveComma(packedOperandsDims.packedDimForEachOperand, - DBGS() << "packedDimForEachOperand: "); - DBGSNL();); + << "\n" + << "maps: " << llvm::interleaved(indexingMaps) << "\n" + << "iterators: " << llvm::interleaved(iteratorTypes) << "\n" + << "packedDimForEachOperand: " + << llvm::interleaved(packedOperandsDims.packedDimForEachOperand) + << "\n"); } // Step 2. Propagate packing to all LinalgOp operands. SmallVector inputsAndInits, results; - SmallVector initOperands = llvm::to_vector(llvm::map_range( - linalgOp.getDpsInitsMutable(), [](OpOperand &o) { return &o; })); + SmallVector initOperands = + llvm::to_vector(llvm::make_pointer_range(linalgOp.getDpsInitsMutable())); SmallVector inputOperands = linalgOp.getDpsInputOperands(); for (const auto &operandsList : {inputOperands, initOperands}) { for (OpOperand *opOperand : operandsList) { @@ -536,11 +539,10 @@ FailureOr linalg::pack(RewriterBase &rewriter, listOfPackedOperandsDim.extractPackedDimsForOperand(pos); SmallVector innerPackSizes = listOfPackedOperandsDim.extractPackSizesForOperand(pos); - LLVM_DEBUG( - DBGS() << "operand: " << operand << "\n"; - llvm::interleaveComma(innerPos, DBGS() << "innerPos: "); DBGSNL(); - llvm::interleaveComma(innerPackSizes, DBGS() << "innerPackSizes: "); - DBGSNL();); + LLVM_DEBUG(DBGS() << "operand: " << operand << "\n" + << "innerPos: " << llvm::interleaved(innerPos) << "\n" + << "innerPackSizes: " + << llvm::interleaved(innerPackSizes) << "\n"); if (innerPackSizes.empty()) { inputsAndInits.push_back(operand); continue; @@ -835,7 +837,7 @@ linalg::packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, // not change the indexings of any operand. SmallVector permutation = computePermutationVector(numLoops, {mPos, nPos, kPos}, mmnnkkPos); - LLVM_DEBUG(llvm::interleaveComma(permutation, DBGS() << "perm: "); DBGSNL();); + LLVM_DEBUG(DBGS() << "perm: " << llvm::interleaved(permutation) << "\n"); // Sign .. unsigned pollution. SmallVector unsignedPerm(permutation.begin(), permutation.end()); FailureOr interchangeResult = @@ -864,12 +866,12 @@ linalg::packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, // Add leading zeros to match numLoops, we only pack the last 3 dimensions // post interchange. - LLVM_DEBUG(llvm::interleaveComma(paddedSizesNextMultipleOf, - DBGS() << "paddedSizesNextMultipleOf: "); - DBGSNL();); - LLVM_DEBUG(llvm::interleaveComma(loopRanges, DBGS() << "loopRanges: ", - [](Range r) { llvm::dbgs() << r.size; }); - DBGSNL();); + LLVM_DEBUG(DBGS() << "paddedSizesNextMultipleOf: " + << llvm::interleaved(paddedSizesNextMultipleOf) << "\n" + << "loopRanges: " + << llvm::interleaved(llvm::map_range( + loopRanges, [](Range r) { return r.size; })) + << "\n"); SmallVector adjustedPackedSizes(numLoops - packedSizes.size(), rewriter.getIndexAttr(0)); for (int64_t i = 0, e = numPackedDims; i < e; ++i) { @@ -885,9 +887,8 @@ linalg::packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, {loopRanges[adjustedPackedSizes.size()].size, rewriter.getIndexAttr(paddedSizesNextMultipleOf[i])})); } - LLVM_DEBUG(llvm::interleaveComma(adjustedPackedSizes, - DBGS() << "adjustedPackedSizes: "); - DBGSNL();); + LLVM_DEBUG(DBGS() << "adjustedPackedSizes: " + << llvm::interleaved(adjustedPackedSizes) << "\n"); // TODO: If we wanted to give the genericOp a name after packing, after // calling `pack` would be a good time. One would still need to check that @@ -1202,9 +1203,9 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite( srcPermForTranspose.append(SmallVector(packOp.getInnerDimsPos())); - LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n"; - llvm::interleaveComma(srcPermForTranspose, DBGS() << "perm: "); - DBGSNL();); + LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n" + << "perm: " << llvm::interleaved(srcPermForTranspose) + << "\n"); // 2.1 Create tensor.empty (init value for TransposeOp) SmallVector transShapeForEmptyOp(srcRank - numTiles,