@@ -455,6 +455,14 @@ struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
455455 if (!operand)
456456 return rewriter.notifyMatchFailure (
457457 subgroupOp, " warp result is not a xegpu::LoadNd op" );
458+ // Make sure the load op is the last operation in the warp op body. This
459+ // ensures that the load op is not sunk prematurely, which would violate
460+ // any barrier synchronizations.
461+ auto yield = cast<gpu::YieldOp>(
462+ subgroupOp.getBodyRegion ().getBlocks ().begin ()->getTerminator ());
463+ Operation *lastNode = yield->getPrevNode ();
464+ if (!dyn_cast_or_null<xegpu::LoadNdOp>(lastNode))
465+ return failure ();
458466
459467 auto loadOp = operand->get ().getDefiningOp <xegpu::LoadNdOp>();
460468 xegpu::TensorDescType tensorDescTy = loadOp.getTensorDescType ();
@@ -782,6 +790,27 @@ struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
782790 }
783791};
784792
793+ struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
794+ using gpu::WarpDistributionPattern::WarpDistributionPattern;
795+ LogicalResult matchAndRewrite (gpu::WarpExecuteOnLane0Op subgroupOp,
796+ PatternRewriter &rewriter) const override {
797+ auto yield = cast<gpu::YieldOp>(
798+ subgroupOp.getBodyRegion ().getBlocks ().begin ()->getTerminator ());
799+ Operation *lastNode = yield->getPrevNode ();
800+ // The last node must be a gpu::BarrierOp.
801+ auto barrierOp = dyn_cast_or_null<gpu::BarrierOp>(lastNode);
802+ if (!barrierOp)
803+ return failure ();
804+ // Simply move the barrier op outside of the warp op.
805+ rewriter.setInsertionPointAfter (subgroupOp);
806+ rewriter.create <gpu::BarrierOp>(
807+ barrierOp.getLoc (), barrierOp->getResultTypes (),
808+ barrierOp->getOperands (), barrierOp->getAttrs ());
809+ rewriter.eraseOp (barrierOp);
810+ return success ();
811+ }
812+ };
813+
785814} // namespace
786815
787816namespace {
@@ -797,6 +826,7 @@ void xegpu::populateXeGPUSubgroupDistributePatterns(
797826 patterns.add <CreateNdDescDistribution, StoreNdDistribution,
798827 LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
799828 UpdateNdOffsetDistribution>(patterns.getContext ());
829+ patterns.add <GpuBarrierDistribution>(patterns.getContext (), 10 );
800830}
801831
802832void XeGPUSubgroupDistributePass::runOnOperation () {
0 commit comments