@@ -300,10 +300,11 @@ getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
300300
301301static bool isGenericOutsNotUsed (linalg::GenericOp genericOp) {
302302 int numDpsOuts = genericOp.getNumDpsInits ();
303+ Block *block = genericOp.getBody ();
304+ int numBlockArgs = block->getNumArguments ();
305+ int initArgStartIndex = numBlockArgs - numDpsOuts;
303306 for (int i = 0 ; i < numDpsOuts; ++i) {
304- Block *block = genericOp.getBody ();
305- int numBlockArgs = block->getNumArguments ();
306- int matchingInitArgIndex = numBlockArgs - numDpsOuts + i;
307+ int matchingInitArgIndex = initArgStartIndex + i;
307308 return block->getArgument (matchingInitArgIndex).use_empty ();
308309 }
309310 return true ;
@@ -312,18 +313,13 @@ static bool isGenericOutsNotUsed(linalg::GenericOp genericOp) {
312313// / Pack a genericOp and return it.
313314static GenericOp packGenericOp (RewriterBase &rewriter, GenericOp genericOp,
314315 Value dest, AffineMap packedOutIndexingMap,
315- const PackInfo &packInfo) {
316+ const PackInfo &packInfo,
317+ bool canUnpackPackFold) {
316318 Location loc = genericOp.getLoc ();
317319 SmallVector<Value> inputOperands;
318320 SmallVector<Value> inputOperandsFromUnpackedSource;
319321 SmallVector<AffineMap> indexingMaps;
320322
321- // Note: canUnpackPackFold needs to also guarantee the generic body
322- // doesn't have gather semantics. Since such scenarios has been
323- // rejected by both BubbleUpPackOpThroughGenericOp and
324- // PushDownUnPackOpThroughGenericOp, we can safely assume
325- // canUnpackPackFold is as long as init is not used.
326- bool canUnpackPackFold = isGenericOutsNotUsed (genericOp);
327323 for (OpOperand *inputOperand : genericOp.getDpsInputOperands ()) {
328324 auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand (
329325 rewriter, loc, packInfo, genericOp, inputOperand);
@@ -338,10 +334,18 @@ static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
338334 indexingMaps.push_back (packedIndexingMap);
339335 }
340336
337+ // Note: Whether or not the unpack pack sequence can fold also depends on
338+ // the caller of this routine.
339+ // 1) In push down unpack op pattern, this is true because the pack op is
340+ // generated and we can guarantee they are compatible.
341+ // 2) In bubble up pack op pattern, this is not true because the unpack op
342+ // can be from an arbitrary domain so we need to keep both.
343+ canUnpackPackFold = canUnpackPackFold && isGenericOutsNotUsed (genericOp) &&
344+ !hasGatherSemantics (genericOp);
341345 // If The pack and unpack op can be folded:
342- // 1) use unpack op source op for operand to fold unpack -> pack sequence
343- // 2) init tensor of the generic op can be replaced by the new tensor.empty
344- // as the generic out .
346+ // 1) use unpack op source op for operand to fold unpack -> pack sequence.
347+ // 2) init tensor of the generic op can be replaced by the destination of the
348+ // pack op .
345349 if (canUnpackPackFold) {
346350 inputOperands = inputOperandsFromUnpackedSource;
347351 if (auto destPack = dest.getDefiningOp <linalg::PackOp>())
@@ -484,7 +488,7 @@ bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp,
484488 dest = packOpDest;
485489 }
486490 return packGenericOp (rewriter, genericOp, dest, packedOutIndexingMap,
487- *packInfo);
491+ *packInfo, /* canUnpackPackFold= */ false );
488492}
489493
490494// / Wrapper pattern that applies bubbleUpPackOpThroughGenericOp method.
@@ -1122,7 +1126,8 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
11221126
11231127 // Pack the genericOp.
11241128 GenericOp newGenericOp =
1125- packGenericOp (rewriter, genericOp, dest, packedOutIndexingMap, *packInfo);
1129+ packGenericOp (rewriter, genericOp, dest, packedOutIndexingMap, *packInfo,
1130+ /* canUnpackPackFold=*/ true );
11261131 Value newResult =
11271132 newGenericOp.getTiedOpResult (newGenericOp.getDpsInitOperand (0 ));
11281133
0 commit comments