@@ -455,6 +455,14 @@ struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
455455    if  (!operand)
456456      return  rewriter.notifyMatchFailure (
457457          subgroupOp, " warp result is not a xegpu::LoadNd op"  );
458+     // Make sure the load op is the last operation in the warp op body. This
459+     // ensures that the load op is not sunk earlier, violating any barrier
460+     // synchronizations.
461+     auto  yield = cast<gpu::YieldOp>(
462+         subgroupOp.getBodyRegion ().getBlocks ().begin ()->getTerminator ());
463+     Operation *lastNode = yield->getPrevNode ();
464+     if  (!dyn_cast_or_null<xegpu::LoadNdOp>(lastNode))
465+       return  failure ();
458466
459467    auto  loadOp = operand->get ().getDefiningOp <xegpu::LoadNdOp>();
460468    xegpu::TensorDescType tensorDescTy = loadOp.getTensorDescType ();
@@ -782,6 +790,29 @@ struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
782790  }
783791};
784792
793+ // / Sink a gpu::BarrierOp at the end of enclosing `gpu.warp_execute_on_lane_0`
794+ // / region. This will simply move the barrier op outside of the warp op.
795+ struct  GpuBarrierDistribution  final  : public gpu::WarpDistributionPattern {
796+   using  gpu::WarpDistributionPattern::WarpDistributionPattern;
797+   LogicalResult matchAndRewrite (gpu::WarpExecuteOnLane0Op subgroupOp,
798+                                 PatternRewriter &rewriter) const  override  {
799+     auto  yield = cast<gpu::YieldOp>(
800+         subgroupOp.getBodyRegion ().getBlocks ().begin ()->getTerminator ());
801+     Operation *lastNode = yield->getPrevNode ();
802+     //  The last node must be a gpu::BarrierOp.
803+     auto  barrierOp = dyn_cast_or_null<gpu::BarrierOp>(lastNode);
804+     if  (!barrierOp)
805+       return  failure ();
806+     //  Move the barrier op outside of the warp op.
807+     rewriter.setInsertionPointAfter (subgroupOp);
808+     rewriter.create <gpu::BarrierOp>(
809+         barrierOp.getLoc (), barrierOp->getResultTypes (),
810+         barrierOp->getOperands (), barrierOp->getAttrs ());
811+     rewriter.eraseOp (barrierOp);
812+     return  success ();
813+   }
814+ };
815+ 
785816} //  namespace
786817
787818namespace  {
// Register all XeGPU subgroup-distribution rewrite patterns (tensor-desc
// creation, load/store/prefetch/dpas/update-offset distribution, and barrier
// sinking) onto the caller-provided pattern set.
void xegpu::populateXeGPUSubgroupDistributePatterns(
    RewritePatternSet &patterns) {
  patterns.add<CreateNdDescDistribution, StoreNdDistribution,
               LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
               UpdateNdOffsetDistribution, GpuBarrierDistribution>(
      patterns.getContext());
}
801833
802834void  XeGPUSubgroupDistributePass::runOnOperation () {
0 commit comments