@@ -455,6 +455,14 @@ struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
455455 if (!operand)
456456 return rewriter.notifyMatchFailure (
457457 subgroupOp, " warp result is not a xegpu::LoadNd op" );
458+ // Make sure the load op is the last operation in the warp op body. This
459+ // ensures that the load op is not sunk earlier, which could violate
460+ // barrier synchronizations.
461+ auto yield = cast<gpu::YieldOp>(
462+ subgroupOp.getBodyRegion ().getBlocks ().begin ()->getTerminator ());
463+ Operation *lastNode = yield->getPrevNode ();
464+ if (!dyn_cast_or_null<xegpu::LoadNdOp>(lastNode))
465+ return failure ();
458466
459467 auto loadOp = operand->get ().getDefiningOp <xegpu::LoadNdOp>();
460468 xegpu::TensorDescType tensorDescTy = loadOp.getTensorDescType ();
@@ -782,6 +790,29 @@ struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
782790 }
783791};
784792
793+ // / Sink a gpu::BarrierOp at the end of enclosing `gpu.warp_execute_on_lane_0`
794+ // / region. This will simply move the barrier op outside of the warp op.
795+ struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
796+ using gpu::WarpDistributionPattern::WarpDistributionPattern;
797+ LogicalResult matchAndRewrite (gpu::WarpExecuteOnLane0Op subgroupOp,
798+ PatternRewriter &rewriter) const override {
799+ auto yield = cast<gpu::YieldOp>(
800+ subgroupOp.getBodyRegion ().getBlocks ().begin ()->getTerminator ());
801+ Operation *lastNode = yield->getPrevNode ();
802+ // The last node must be a gpu::BarrierOp.
803+ auto barrierOp = dyn_cast_or_null<gpu::BarrierOp>(lastNode);
804+ if (!barrierOp)
805+ return failure ();
806+ // Move the barrier op outside of the warp op.
807+ rewriter.setInsertionPointAfter (subgroupOp);
808+ rewriter.create <gpu::BarrierOp>(
809+ barrierOp.getLoc (), barrierOp->getResultTypes (),
810+ barrierOp->getOperands (), barrierOp->getAttrs ());
811+ rewriter.eraseOp (barrierOp);
812+ return success ();
813+ }
814+ };
815+
785816} // namespace
786817
787818namespace {
@@ -796,7 +827,8 @@ void xegpu::populateXeGPUSubgroupDistributePatterns(
796827 RewritePatternSet &patterns) {
797828 patterns.add <CreateNdDescDistribution, StoreNdDistribution,
798829 LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
799- UpdateNdOffsetDistribution>(patterns.getContext ());
830+ UpdateNdOffsetDistribution, GpuBarrierDistribution>(
831+ patterns.getContext ());
800832}
801833
802834void XeGPUSubgroupDistributePass::runOnOperation () {
0 commit comments