Skip to content

Commit e014914

Browse files
committed
fix
1 parent 98e8ef2 commit e014914

File tree

2 files changed

+49
-0
lines changed

2 files changed

+49
-0
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -455,6 +455,14 @@ struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
455455
if (!operand)
456456
return rewriter.notifyMatchFailure(
457457
subgroupOp, "warp result is not a xegpu::LoadNd op");
458+
// Make sure the load op is the last operation in the warp op body. This
// ensures that the load op is not sunk earlier, violating any barrier
// synchronizations.
461+
auto yield = cast<gpu::YieldOp>(
462+
subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
463+
Operation *lastNode = yield->getPrevNode();
464+
if (!dyn_cast_or_null<xegpu::LoadNdOp>(lastNode))
465+
return failure();
458466

459467
auto loadOp = operand->get().getDefiningOp<xegpu::LoadNdOp>();
460468
xegpu::TensorDescType tensorDescTy = loadOp.getTensorDescType();
@@ -782,6 +790,27 @@ struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
782790
}
783791
};
784792

793+
struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
794+
using gpu::WarpDistributionPattern::WarpDistributionPattern;
795+
LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
796+
PatternRewriter &rewriter) const override {
797+
auto yield = cast<gpu::YieldOp>(
798+
subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
799+
Operation *lastNode = yield->getPrevNode();
800+
// The last node must be a gpu::BarrierOp.
801+
auto barrierOp = dyn_cast_or_null<gpu::BarrierOp>(lastNode);
802+
if (!barrierOp)
803+
return failure();
804+
// Simply move the barrier op outside of the warp op.
805+
rewriter.setInsertionPointAfter(subgroupOp);
806+
rewriter.create<gpu::BarrierOp>(
807+
barrierOp.getLoc(), barrierOp->getResultTypes(),
808+
barrierOp->getOperands(), barrierOp->getAttrs());
809+
rewriter.eraseOp(barrierOp);
810+
return success();
811+
}
812+
};
813+
785814
} // namespace
786815

787816
namespace {
@@ -797,6 +826,7 @@ void xegpu::populateXeGPUSubgroupDistributePatterns(
797826
patterns.add<CreateNdDescDistribution, StoreNdDistribution,
798827
LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
799828
UpdateNdOffsetDistribution>(patterns.getContext());
829+
patterns.add<GpuBarrierDistribution>(patterns.getContext(), 10);
800830
}
801831

802832
void XeGPUSubgroupDistributePass::runOnOperation() {

mlir/test/Dialect/XeGPU/subgroup-distribute.mlir

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,3 +278,22 @@ gpu.module @test {
278278
gpu.return
279279
}
280280
}
281+
282+
// -----
// CHECK-LABEL: gpu.func @gpu_barrier({{.*}}) {
// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<256xf16> -> !xegpu.tensor_desc<16xf16>
// CHECK-NEXT: %[[T1:.*]] = xegpu.load_nd %[[T0]] : !xegpu.tensor_desc<16xf16> -> vector<1xf16>
// CHECK-NEXT: gpu.barrier
// CHECK-NEXT: %[[T2:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<256xf16> -> !xegpu.tensor_desc<16xf16>
// CHECK-NEXT: xegpu.store_nd %[[T1]], %[[T2]] : vector<1xf16>, !xegpu.tensor_desc<16xf16>
289+
gpu.module @test {
  // Load, synchronize, then store: the gpu.barrier between the load and the
  // store must be preserved in order after subgroup distribution (see the
  // CHECK-NEXT lines above).
  gpu.func @gpu_barrier(%arg0: memref<256xf16>, %arg1: memref<256xf16>) {
    %c0 = arith.constant 0 : index
    %0 = xegpu.create_nd_tdesc %arg0[%c0] : memref<256xf16> -> !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
    %1 = xegpu.load_nd %0 {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} : !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>> -> vector<16xf16>
    gpu.barrier
    %2 = xegpu.create_nd_tdesc %arg1[%c0] : memref<256xf16> -> !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
    xegpu.store_nd %1, %2 : vector<16xf16>, !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
    gpu.return
  }
}

0 commit comments

Comments
 (0)