@@ -34,6 +34,7 @@
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/SmallVectorExtras.h"
 
 namespace mlir {
 namespace xegpu {
@@ -197,9 +198,17 @@ struct MoveFuncBodyToWarpExecuteOnLane0
           return isa<gpu::WarpExecuteOnLane0Op>(op);
         }))
       return failure();
-    // Create a new function with the same signature.
+    // Create a new function with the same signature and same attributes.
+    SmallVector<Type> workgroupAttributionsTypes =
+        llvm::map_to_vector(gpuFuncOp.getWorkgroupAttributions(),
+                            [](BlockArgument arg) { return arg.getType(); });
+    SmallVector<Type> privateAttributionsTypes =
+        llvm::map_to_vector(gpuFuncOp.getPrivateAttributions(),
+                            [](BlockArgument arg) { return arg.getType(); });
     auto newGpuFunc = rewriter.create<gpu::GPUFuncOp>(
-        gpuFuncOp.getLoc(), gpuFuncOp.getName(), gpuFuncOp.getFunctionType());
+        gpuFuncOp.getLoc(), gpuFuncOp.getName(), gpuFuncOp.getFunctionType(),
+        workgroupAttributionsTypes, privateAttributionsTypes);
+    newGpuFunc->setAttrs(gpuFuncOp->getAttrs());
     // Create a WarpExecuteOnLane0Op with same arguments and results as the
     // original gpuFuncOp.
     rewriter.setInsertionPointToEnd(&newGpuFunc.getFunctionBody().front());
@@ -265,13 +274,13 @@ struct MoveFuncBodyToWarpExecuteOnLane0
 /// ```
 struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
-  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     OpOperand *operand =
-        getWarpResult(subgroupOp, llvm::IsaPred<xegpu::CreateNdDescOp>);
+        getWarpResult(warpOp, llvm::IsaPred<xegpu::CreateNdDescOp>);
     if (!operand)
       return rewriter.notifyMatchFailure(
-          subgroupOp, "warp result is not a xegpu::CreateNdDesc op");
+          warpOp, "warp result is not a xegpu::CreateNdDesc op");
     auto descOp = operand->get().getDefiningOp<xegpu::CreateNdDescOp>();
     unsigned operandIdx = operand->getOperandNumber();
 
@@ -288,9 +297,9 @@ struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
       newYieldValues.push_back(operand);
       newYieldTypes.push_back(operand.getType());
     }
-    rewriter.setInsertionPoint(subgroupOp);
+    rewriter.setInsertionPoint(warpOp);
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, subgroupOp, /* new yielded values = */ newYieldValues,
+        rewriter, warpOp, /* new yielded values = */ newYieldValues,
         /* new yielded types = */ newYieldTypes, newRetIndices);
 
     SmallVector<Value> newDescOperands;
@@ -347,10 +356,10 @@ struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
 /// ```
 struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
-  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     auto yield = cast<gpu::YieldOp>(
-        subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
+        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
     Operation *lastNode = yield->getPrevNode();
     auto storeOp = dyn_cast_or_null<xegpu::StoreNdOp>(lastNode);
     if (!storeOp)
@@ -372,7 +381,7 @@ struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
 
     SmallVector<size_t> newRetIndices;
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, subgroupOp,
+        rewriter, warpOp,
         /* new yielded values = */
         ValueRange{storeOp.getValue(), storeOp.getTensorDesc()},
         /* new yielded types = */
@@ -449,21 +458,22 @@ struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
 /// ```
 struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
-  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    OpOperand *operand =
-        getWarpResult(subgroupOp, llvm::IsaPred<xegpu::LoadNdOp>);
+    OpOperand *operand = getWarpResult(warpOp, [&](Operation *op) {
+      if (!isa<xegpu::LoadNdOp>(op))
+        return false;
+      // Make sure this load op is the last operation in the warp op body.
+      // This ensures that the load op is not sunk earlier, which would
+      // violate any barrier synchronization.
+      auto yield = cast<gpu::YieldOp>(
+          warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+      return yield->getPrevNode() == op;
+    });
+
     if (!operand)
       return rewriter.notifyMatchFailure(
-          subgroupOp, "warp result is not a xegpu::LoadNd op");
-    // Make sure the load op is the last operation in the warp op body. This
-    // ensures that the load op is not sunk earlier, violating any barrier
-    // synchronization.
-    auto yield = cast<gpu::YieldOp>(
-        subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
-    Operation *lastNode = yield->getPrevNode();
-    if (!dyn_cast_or_null<xegpu::LoadNdOp>(lastNode))
-      return failure();
+          warpOp, "warp result is not a xegpu::LoadNd op");
 
     auto loadOp = operand->get().getDefiningOp<xegpu::LoadNdOp>();
     xegpu::TensorDescType tensorDescTy = loadOp.getTensorDescType();
@@ -474,11 +484,11 @@ struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
 
     unsigned operandIdx = operand->getOperandNumber();
     VectorType distributedTypeByWarpOp =
-        cast<VectorType>(subgroupOp.getResult(operandIdx).getType());
+        cast<VectorType>(warpOp.getResult(operandIdx).getType());
 
     SmallVector<size_t> newRetIndices;
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, subgroupOp,
+        rewriter, warpOp,
         /* new yielded values = */ loadOp.getTensorDesc(),
         /* new yielded types = */ tensorDescTy, newRetIndices);
 
@@ -548,12 +558,11 @@ struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
 /// ```
 struct DpasDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
-  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    OpOperand *operand =
-        getWarpResult(subgroupOp, llvm::IsaPred<xegpu::DpasOp>);
+    OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<xegpu::DpasOp>);
     if (!operand)
-      return rewriter.notifyMatchFailure(subgroupOp,
+      return rewriter.notifyMatchFailure(warpOp,
                                          "warp result is not a xegpu::Dpas op");
 
     auto dpasOp = operand->get().getDefiningOp<xegpu::DpasOp>();
@@ -599,7 +608,7 @@ struct DpasDistribution final : public gpu::WarpDistributionPattern {
     // Create a new warp op without the dpas.
     SmallVector<size_t> newRetIndices;
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);
+        rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);
 
     FailureOr<VectorType> expectedDistLhsTyOrFailure =
         xegpu::getDistributedVectorType(dpasOp.getLhsType(), layoutA);
@@ -678,13 +687,13 @@ struct DpasDistribution final : public gpu::WarpDistributionPattern {
 /// ```
 struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
-  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     OpOperand *operand =
-        getWarpResult(subgroupOp, llvm::IsaPred<xegpu::UpdateNdOffsetOp>);
+        getWarpResult(warpOp, llvm::IsaPred<xegpu::UpdateNdOffsetOp>);
     if (!operand)
       return rewriter.notifyMatchFailure(
-          subgroupOp, "warp result is not a xegpu::UpdateNdOffset op");
+          warpOp, "warp result is not a xegpu::UpdateNdOffset op");
     auto updateOp = operand->get().getDefiningOp<xegpu::UpdateNdOffsetOp>();
     unsigned operandIdx = operand->getOperandNumber();
     // new update op does not have layout attribute.
@@ -703,7 +712,7 @@ struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
     }
     SmallVector<size_t> newRetIndices;
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);
+        rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);
     rewriter.setInsertionPointAfter(newWarpOp);
     SmallVector<Value> newUpdateOperands;
     for (size_t i : newRetIndices) {
@@ -758,10 +767,10 @@ struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
 /// ```
 struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
-  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     auto yield = cast<gpu::YieldOp>(
-        subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
+        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
     Operation *lastNode = yield->getPrevNode();
     auto prefetchOp = dyn_cast_or_null<xegpu::PrefetchNdOp>(lastNode);
     if (!prefetchOp)
@@ -775,7 +784,7 @@ struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
     SmallVector<Type, 1> newYieldTypes = {prefetchOp.getTensorDescType()};
     SmallVector<size_t> newRetIndices;
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);
+        rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);
     // Create a new prefetch op outside the warp op with updated tensor
     // descriptor type. Source tensor descriptor require type resolution.
     xegpu::TensorDescType newTensorDescTy =
@@ -795,17 +804,17 @@ struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
 /// region. This will simply move the barrier op outside of the warp op.
 struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
-  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     auto yield = cast<gpu::YieldOp>(
-        subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
+        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
     Operation *lastNode = yield->getPrevNode();
     // The last node must be a gpu::BarrierOp.
     auto barrierOp = dyn_cast_or_null<gpu::BarrierOp>(lastNode);
     if (!barrierOp)
       return failure();
     // Move the barrier op outside of the warp op.
-    rewriter.setInsertionPointAfter(subgroupOp);
+    rewriter.setInsertionPointAfter(warpOp);
     rewriter.create<gpu::BarrierOp>(
         barrierOp.getLoc(), barrierOp->getResultTypes(),
         barrierOp->getOperands(), barrierOp->getAttrs());
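
For reference, the llvm::map_to_vector helper used in the first hunk (pulled in via the new llvm/ADT/SmallVectorExtras.h include) maps each element of a range through a callable and collects the results into a SmallVector; that is how the patch gathers the types of the workgroup/private attributions. Below is a minimal standalone sketch of that idiom; the Item struct and sample values are illustrative only and not part of the patch.

#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/SmallVectorExtras.h"

namespace {
// Hypothetical element type, standing in for BlockArgument in the patch.
struct Item {
  int value;
  int getValue() const { return value; }
};
} // namespace

int main() {
  // Illustrative input range; any iterable range works.
  llvm::SmallVector<Item> items = {{1}, {2}, {3}};
  // map_to_vector applies the lambda to each element and returns a
  // SmallVector of the results, mirroring how the patch collects
  // arg.getType() for each attribution.
  llvm::SmallVector<int> values =
      llvm::map_to_vector(items, [](const Item &it) { return it.getValue(); });
  return values.size() == 3 ? 0 : 1;
}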