@@ -135,19 +135,6 @@ static Value resolveDistributedTy(Value orig, T expected,
135135 return orig;
136136}
137137
138- /// Helper function to filter out the temporary layout attributes attached
139- /// during the layout assignment process. These are not needed after going to
140- /// SIMT.
141- static SmallVector<NamedAttribute>
142- removeTemporaryLayoutAttributes(ArrayRef<NamedAttribute> attrs) {
143-   SmallVector<NamedAttribute> newAttrs;
144-   for (NamedAttribute attr : attrs) {
145-     if (!isa<xegpu::LayoutAttr>(attr.getValue()))
146-       newAttrs.push_back(attr);
147-   }
148-   return newAttrs;
149- }
150-
151138 /// Helper function to check if the layout is packed. Layout is packed if it is
152139 /// 2D and lane_data[0] != 1 (data packed from col dimension).
153140 static bool hasPackedLayout(xegpu::LayoutAttr layout) {
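The helper deleted above returned a filtered copy of an attribute list; the hunks below switch to the shared xegpu::removeLayoutAttrs utility, which instead strips layout attributes from an already-built op in place. A minimal sketch of that in-place idiom using only core MLIR APIs (the in-tree utility may differ):

    // Sketch: drop every op attribute whose value is a xegpu::LayoutAttr.
    static void removeLayoutAttrsSketch(Operation *op) {
      // Copy the list first; erasing while iterating the live attribute
      // range would invalidate it.
      for (NamedAttribute attr : llvm::to_vector(op->getAttrs()))
        if (isa<xegpu::LayoutAttr>(attr.getValue()))
          op->removeAttr(attr.getName());
    }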
@@ -197,9 +184,17 @@ struct MoveFuncBodyToWarpExecuteOnLane0
197184 return isa<gpu::WarpExecuteOnLane0Op>(op);
198185 }))
199186 return failure ();
200- // Create a new function with the same signature.
187+ // Create a new function with the same signature and same attributes.
188+ SmallVector<Type> workgroupAttributionsTypes =
189+ llvm::map_to_vector (gpuFuncOp.getWorkgroupAttributions (),
190+ [](BlockArgument arg) { return arg.getType (); });
191+ SmallVector<Type> privateAttributionsTypes =
192+ llvm::map_to_vector (gpuFuncOp.getPrivateAttributions (),
193+ [](BlockArgument arg) { return arg.getType (); });
201194 auto newGpuFunc = rewriter.create <gpu::GPUFuncOp>(
202- gpuFuncOp.getLoc (), gpuFuncOp.getName (), gpuFuncOp.getFunctionType ());
195+ gpuFuncOp.getLoc (), gpuFuncOp.getName (), gpuFuncOp.getFunctionType (),
196+ workgroupAttributionsTypes, privateAttributionsTypes);
197+ newGpuFunc->setAttrs (gpuFuncOp->getAttrs ());
203198 // Create a WarpExecuteOnLane0Op with same arguments and results as the
204199 // original gpuFuncOp.
205200 rewriter.setInsertionPointToEnd (&newGpuFunc.getFunctionBody ().front ());
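The extra builder arguments matter because gpu.func models workgroup and private memory attributions as additional entry-block arguments: a function rebuilt from only its name and FunctionType silently drops those buffers. The setAttrs call likewise carries over markers such as the gpu.kernel unit attribute, which the signature-only builder in the removed line lost.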
@@ -265,13 +260,13 @@ struct MoveFuncBodyToWarpExecuteOnLane0
265260 /// ```
266261 struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
267262   using gpu::WarpDistributionPattern::WarpDistributionPattern;
268-   LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
263+   LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
269264                                 PatternRewriter &rewriter) const override {
270265     OpOperand *operand =
271-         getWarpResult(subgroupOp, llvm::IsaPred<xegpu::CreateNdDescOp>);
266+         getWarpResult(warpOp, llvm::IsaPred<xegpu::CreateNdDescOp>);
272267     if (!operand)
273268       return rewriter.notifyMatchFailure(
274-           subgroupOp, "warp result is not a xegpu::CreateNdDesc op");
269+           warpOp, "warp result is not a xegpu::CreateNdDesc op");
275270     auto descOp = operand->get().getDefiningOp<xegpu::CreateNdDescOp>();
276271     unsigned operandIdx = operand->getOperandNumber();
277272
@@ -288,9 +283,9 @@ struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
288283       newYieldValues.push_back(operand);
289284       newYieldTypes.push_back(operand.getType());
290285     }
291-     rewriter.setInsertionPoint(subgroupOp);
286+     rewriter.setInsertionPoint(warpOp);
292287     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
293-         rewriter, subgroupOp, /* new yielded values = */ newYieldValues,
288+         rewriter, warpOp, /* new yielded values = */ newYieldValues,
294289 /* new yielded types = */ newYieldTypes, newRetIndices);
295290
296291 SmallVector<Value> newDescOperands;
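Every pattern in this file funnels values out of the warp region the same way: append them to the region's gpu.yield, rebuild the warp op with matching extra result types, then read them back through newRetIndices. A condensed sketch of the idiom with illustrative names:

    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, /*newYieldedValues=*/ValueRange{innerVal},
        /*newYieldedTypes=*/TypeRange{innerVal.getType()}, newRetIndices);
    // innerVal is now reachable outside the warp region:
    Value escaped = newWarpOp.getResult(newRetIndices[0]);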
@@ -347,10 +342,10 @@ struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
347342 /// ```
348343 struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
349344   using gpu::WarpDistributionPattern::WarpDistributionPattern;
350-   LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
345+   LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
351346                                 PatternRewriter &rewriter) const override {
352347     auto yield = cast<gpu::YieldOp>(
353-         subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
348+         warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
354349     Operation *lastNode = yield->getPrevNode();
355350 auto storeOp = dyn_cast_or_null<xegpu::StoreNdOp>(lastNode);
356351 if (!storeOp)
@@ -372,7 +367,7 @@ struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
372367
373368     SmallVector<size_t> newRetIndices;
374369     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
375-         rewriter, subgroupOp,
370+         rewriter, warpOp,
376371         /* new yielded values = */
377372         ValueRange{storeOp.getValue(), storeOp.getTensorDesc()},
378373 /* new yielded types = */
@@ -403,9 +398,9 @@ struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
403398         resolveDistributedTy(newWarpOp.getResult(newRetIndices[1]),
404399                              distributedTensorDescTy, rewriter));
405400
406-     rewriter.create<xegpu::StoreNdOp>(
407-         newWarpOp.getLoc(), TypeRange{}, newStoreOperands,
408-         removeTemporaryLayoutAttributes(storeOp->getAttrs()));
401+     auto newStoreOp = rewriter.create<xegpu::StoreNdOp>(
402+         newWarpOp.getLoc(), TypeRange{}, newStoreOperands, storeOp->getAttrs());
403+     xegpu::removeLayoutAttrs(newStoreOp);
409404     rewriter.eraseOp(storeOp);
410405     return success();
411406 }
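This reordering is the core of the change: the new op is first created with the source op's full attribute list (cache hints included), and xegpu::removeLayoutAttrs then deletes just the temporary layout entries in place, removing the need for the per-call removeTemporaryLayoutAttributes filter. The same copy-then-strip sequence repeats in the load, dpas, update-offset, and prefetch hunks below.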
@@ -449,21 +444,22 @@ struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
449444 /// ```
450445 struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
451446   using gpu::WarpDistributionPattern::WarpDistributionPattern;
452-   LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
447+   LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
453448                                 PatternRewriter &rewriter) const override {
454-     OpOperand *operand =
455-         getWarpResult(subgroupOp, llvm::IsaPred<xegpu::LoadNdOp>);
449+     OpOperand *operand = getWarpResult(warpOp, [&](Operation *op) {
450+       if (!isa<xegpu::LoadNdOp>(op))
451+         return false;
452+       // Make sure this load op is the last operation in the warp op body.
453+       // This ensures the load op is not sunk earlier, violating any barrier
454+       // synchronization.
455+       auto yield = cast<gpu::YieldOp>(
456+           warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
457+       return yield->getPrevNode() == op;
458+     });
459+
456460     if (!operand)
457461       return rewriter.notifyMatchFailure(
458-           subgroupOp, "warp result is not a xegpu::LoadNd op");
459-     // Make sure the load op is the last operation in the warp op body. This
460-     // ensure that load op is not sinked earlier violating any barrier
461-     // synchronizations.
462-     auto yield = cast<gpu::YieldOp>(
463-         subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
464-     Operation *lastNode = yield->getPrevNode();
465-     if (!dyn_cast_or_null<xegpu::LoadNdOp>(lastNode))
466-       return failure();
462+           warpOp, "warp result is not a xegpu::LoadNd op");
467463
468464     auto loadOp = operand->get().getDefiningOp<xegpu::LoadNdOp>();
469465     xegpu::TensorDescType tensorDescTy = loadOp.getTensorDescType();
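The folded-in predicate is also stricter than the two-step check it replaces: the old code took the first yielded xegpu::LoadNdOp and then only required that the op before the terminator be some LoadNdOp, so the matched load and the trailing load could be different ops. The new lambda requires the matched load itself to sit directly before the yield, which is what actually guarantees it cannot be sunk past a barrier.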
@@ -474,11 +470,11 @@ struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
474470
475471     unsigned operandIdx = operand->getOperandNumber();
476472     VectorType distributedTypeByWarpOp =
477-         cast<VectorType>(subgroupOp.getResult(operandIdx).getType());
473+         cast<VectorType>(warpOp.getResult(operandIdx).getType());
478474
479475     SmallVector<size_t> newRetIndices;
480476     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
481-         rewriter, subgroupOp,
477+         rewriter, warpOp,
482478         /* new yielded values = */ loadOp.getTensorDesc(),
483479 /* new yielded types = */ tensorDescTy, newRetIndices);
484480
@@ -498,7 +494,8 @@ struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
498494         newWarpOp.getLoc(), loadNdDistValueTyOrFailure.value(),
499495         resolveDistributedTy(newWarpOp->getResult(newRetIndices[0]),
500496                              distributedTensorDescTy, rewriter),
501-         removeTemporaryLayoutAttributes(loadOp->getAttrs()));
497+         loadOp->getAttrs());
498+     xegpu::removeLayoutAttrs(newLoadOp);
502499     // Set the packed attribute if the layout requires it.
503500     newLoadOp.setPacked(hasPackedLayout(layout));
504501     Value distributedVal = newWarpOp.getResult(operandIdx);
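hasPackedLayout appears in the first hunk only as a signature; a plausible body matching its doc comment, assuming xegpu::LayoutAttr exposes lane_data as a DenseI32ArrayAttr (a sketch, the in-tree version may differ):

    static bool hasPackedLayout(xegpu::LayoutAttr layout) {
      if (!layout || !layout.getLaneData())
        return false;
      ArrayRef<int32_t> laneData = layout.getLaneData().asArrayRef();
      // Packed iff the layout is 2D and data is packed along the column dim.
      return laneData.size() == 2 && laneData[0] != 1;
    }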
@@ -548,12 +545,11 @@ struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
548545 /// ```
549546 struct DpasDistribution final : public gpu::WarpDistributionPattern {
550547   using gpu::WarpDistributionPattern::WarpDistributionPattern;
551-   LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
548+   LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
552549                                 PatternRewriter &rewriter) const override {
553-     OpOperand *operand =
554-         getWarpResult(subgroupOp, llvm::IsaPred<xegpu::DpasOp>);
550+     OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<xegpu::DpasOp>);
555551     if (!operand)
556-       return rewriter.notifyMatchFailure(subgroupOp,
552+       return rewriter.notifyMatchFailure(warpOp,
557553                                          "warp result is not a xegpu::Dpas op");
558554
559555     auto dpasOp = operand->get().getDefiningOp<xegpu::DpasOp>();
@@ -599,7 +595,7 @@ struct DpasDistribution final : public gpu::WarpDistributionPattern {
599595 // Create a new warp op without the dpas.
600596     SmallVector<size_t> newRetIndices;
601597     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
602-         rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);
598+         rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);
603599
604600     FailureOr<VectorType> expectedDistLhsTyOrFailure =
605601         xegpu::getDistributedVectorType(dpasOp.getLhsType(), layoutA);
@@ -630,14 +626,16 @@ struct DpasDistribution final : public gpu::WarpDistributionPattern {
630626           resolveDistributedTy(newWarpOp.getResult(newRetIndices[i]),
631627                                newDpasOperandExpectedTypes[i], rewriter));
632628     }
633-     Value newDpasOp = rewriter.create<xegpu::DpasOp>(
634-         newWarpOp->getLoc(), distributedResultTy, newDpasOperands,
635-         removeTemporaryLayoutAttributes(dpasOp->getAttrs()));
629+     auto newDpasOp =
630+         rewriter.create<xegpu::DpasOp>(newWarpOp->getLoc(), distributedResultTy,
631+                                        newDpasOperands, dpasOp->getAttrs());
632+     xegpu::removeLayoutAttrs(newDpasOp);
636633     Value distributedVal = newWarpOp.getResult(operandIdx);
637634     // Resolve the output type.
638-     newDpasOp = resolveDistributedTy(
639-         newDpasOp, distResultTypeByWarpOpOrFailure.value(), rewriter);
640-     rewriter.replaceAllUsesWith(distributedVal, newDpasOp);
635+     Value typeResolved =
636+         resolveDistributedTy(newDpasOp.getResult(),
637+                              distResultTypeByWarpOpOrFailure.value(), rewriter);
638+     rewriter.replaceAllUsesWith(distributedVal, typeResolved);
641639     return success();
642640 }
643641};
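The switch from Value newDpasOp to a typed auto handle is what lets xegpu::removeLayoutAttrs receive the op itself, while resolveDistributedTy still wants a Value: it may materialize a conversion cast, so its result can no longer be assigned back into the typed handle. A sketch of the resulting split, with illustrative names:

    auto newOp = rewriter.create<xegpu::DpasOp>(loc, resultTy, operands, attrs);
    xegpu::removeLayoutAttrs(newOp);          // needs the op handle
    Value resolved = resolveDistributedTy(    // may insert a cast
        newOp.getResult(), expectedTy, rewriter);
    rewriter.replaceAllUsesWith(oldResult, resolved);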
@@ -678,13 +676,13 @@ struct DpasDistribution final : public gpu::WarpDistributionPattern {
678676 /// ```
679677 struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
680678   using gpu::WarpDistributionPattern::WarpDistributionPattern;
681-   LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
679+   LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
682680                                 PatternRewriter &rewriter) const override {
683681     OpOperand *operand =
684-         getWarpResult(subgroupOp, llvm::IsaPred<xegpu::UpdateNdOffsetOp>);
682+         getWarpResult(warpOp, llvm::IsaPred<xegpu::UpdateNdOffsetOp>);
685683     if (!operand)
686684       return rewriter.notifyMatchFailure(
687-           subgroupOp, "warp result is not a xegpu::UpdateNdOffset op");
685+           warpOp, "warp result is not a xegpu::UpdateNdOffset op");
688686     auto updateOp = operand->get().getDefiningOp<xegpu::UpdateNdOffsetOp>();
689687     unsigned operandIdx = operand->getOperandNumber();
690688     // The new update op does not have a layout attribute.
@@ -703,7 +701,7 @@ struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
703701 }
704702     SmallVector<size_t> newRetIndices;
705703     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
706-         rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);
704+         rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);
707705     rewriter.setInsertionPointAfter(newWarpOp);
708706 SmallVector<Value> newUpdateOperands;
709707 for (size_t i : newRetIndices) {
@@ -717,14 +715,15 @@ struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
717715 }
718716 }
719717 // Create a new update op outside the warp op.
720-     Value newUpdateOp = rewriter.create<xegpu::UpdateNdOffsetOp>(
718+     auto newUpdateOp = rewriter.create<xegpu::UpdateNdOffsetOp>(
721719         newWarpOp.getLoc(), newTensorDescTy, newUpdateOperands,
722-         removeTemporaryLayoutAttributes(updateOp->getAttrs()));
720+         updateOp->getAttrs());
721+     xegpu::removeLayoutAttrs(newUpdateOp);
723722     Value distributedVal = newWarpOp.getResult(operandIdx);
724723     // Resolve the distributed type with the original type.
725-     newUpdateOp =
726-         resolveDistributedTy(newUpdateOp, distributedVal.getType(), rewriter);
727-     rewriter.replaceAllUsesWith(distributedVal, newUpdateOp);
724+     Value typeResolved = resolveDistributedTy(
725+         newUpdateOp.getResult(), distributedVal.getType(), rewriter);
726+     rewriter.replaceAllUsesWith(distributedVal, typeResolved);
728727     return success();
729728 }
730729};
@@ -758,10 +757,10 @@ struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
758757 /// ```
759758 struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
760759   using gpu::WarpDistributionPattern::WarpDistributionPattern;
761-   LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
760+   LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
762761                                 PatternRewriter &rewriter) const override {
763762     auto yield = cast<gpu::YieldOp>(
764-         subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
763+         warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
765764     Operation *lastNode = yield->getPrevNode();
766765 auto prefetchOp = dyn_cast_or_null<xegpu::PrefetchNdOp>(lastNode);
767766 if (!prefetchOp)
@@ -775,17 +774,18 @@ struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
775774     SmallVector<Type, 1> newYieldTypes = {prefetchOp.getTensorDescType()};
776775     SmallVector<size_t> newRetIndices;
777776     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
778-         rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);
777+         rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);
779778     // Create a new prefetch op outside the warp op with updated tensor
780779     // descriptor type. The source tensor descriptor requires type resolution.
781780     xegpu::TensorDescType newTensorDescTy =
782781         prefetchOp.getTensorDescType().dropLayouts();
783782     rewriter.setInsertionPointAfter(newWarpOp);
784783     SmallVector<Value> newPrefetchOperands = {resolveDistributedTy(
785784         newWarpOp.getResult(newRetIndices[0]), newTensorDescTy, rewriter)};
786-     rewriter.create<xegpu::PrefetchNdOp>(
787-         newWarpOp.getLoc(), TypeRange{}, newPrefetchOperands,
788-         removeTemporaryLayoutAttributes(prefetchOp->getAttrs()));
785+     auto newPrefetchOp = rewriter.create<xegpu::PrefetchNdOp>(
786+         newWarpOp.getLoc(), TypeRange{}, newPrefetchOperands,
787+         prefetchOp->getAttrs());
788+     xegpu::removeLayoutAttrs(newPrefetchOp);
789789     rewriter.eraseOp(prefetchOp);
790790     return success();
791791 }
@@ -795,17 +795,17 @@ struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
795795 /// region. This will simply move the barrier op outside of the warp op.
796796 struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
797797   using gpu::WarpDistributionPattern::WarpDistributionPattern;
798-   LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
798+   LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
799799                                 PatternRewriter &rewriter) const override {
800800     auto yield = cast<gpu::YieldOp>(
801-         subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
801+         warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
802802     Operation *lastNode = yield->getPrevNode();
803803     // The last node must be a gpu::BarrierOp.
804804     auto barrierOp = dyn_cast_or_null<gpu::BarrierOp>(lastNode);
805805     if (!barrierOp)
806806       return failure();
807807     // Move the barrier op outside of the warp op.
808-     rewriter.setInsertionPointAfter(subgroupOp);
808+     rewriter.setInsertionPointAfter(warpOp);
809809     rewriter.create<gpu::BarrierOp>(
810810         barrierOp.getLoc(), barrierOp->getResultTypes(),
811811         barrierOp->getOperands(), barrierOp->getAttrs());