Skip to content

Commit 141c551

Browse files
committed
save work
1 parent 4ea0ef2 commit 141c551

File tree

2 files changed

+52
-43
lines changed

2 files changed

+52
-43
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp

Lines changed: 48 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
#include "llvm/ADT/ArrayRef.h"
3535
#include "llvm/ADT/STLExtras.h"
3636
#include "llvm/ADT/SmallVector.h"
37+
#include "llvm/ADT/SmallVectorExtras.h"
3738

3839
namespace mlir {
3940
namespace xegpu {
@@ -197,9 +198,17 @@ struct MoveFuncBodyToWarpExecuteOnLane0
197198
return isa<gpu::WarpExecuteOnLane0Op>(op);
198199
}))
199200
return failure();
200-
// Create a new function with the same signature.
201+
// Create a new function with the same signature and same attributes.
202+
SmallVector<Type> workgroupAttributionsTypes =
203+
llvm::map_to_vector(gpuFuncOp.getWorkgroupAttributions(),
204+
[](BlockArgument arg) { return arg.getType(); });
205+
SmallVector<Type> privateAttributionsTypes =
206+
llvm::map_to_vector(gpuFuncOp.getPrivateAttributions(),
207+
[](BlockArgument arg) { return arg.getType(); });
201208
auto newGpuFunc = rewriter.create<gpu::GPUFuncOp>(
202-
gpuFuncOp.getLoc(), gpuFuncOp.getName(), gpuFuncOp.getFunctionType());
209+
gpuFuncOp.getLoc(), gpuFuncOp.getName(), gpuFuncOp.getFunctionType(),
210+
workgroupAttributionsTypes, privateAttributionsTypes);
211+
newGpuFunc->setAttrs(gpuFuncOp->getAttrs());
203212
// Create a WarpExecuteOnLane0Op with same arguments and results as the
204213
// original gpuFuncOp.
205214
rewriter.setInsertionPointToEnd(&newGpuFunc.getFunctionBody().front());
@@ -265,13 +274,13 @@ struct MoveFuncBodyToWarpExecuteOnLane0
265274
/// ```
266275
struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
267276
using gpu::WarpDistributionPattern::WarpDistributionPattern;
268-
LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
277+
LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
269278
PatternRewriter &rewriter) const override {
270279
OpOperand *operand =
271-
getWarpResult(subgroupOp, llvm::IsaPred<xegpu::CreateNdDescOp>);
280+
getWarpResult(warpOp, llvm::IsaPred<xegpu::CreateNdDescOp>);
272281
if (!operand)
273282
return rewriter.notifyMatchFailure(
274-
subgroupOp, "warp result is not a xegpu::CreateNdDesc op");
283+
warpOp, "warp result is not a xegpu::CreateNdDesc op");
275284
auto descOp = operand->get().getDefiningOp<xegpu::CreateNdDescOp>();
276285
unsigned operandIdx = operand->getOperandNumber();
277286

@@ -288,9 +297,9 @@ struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
288297
newYieldValues.push_back(operand);
289298
newYieldTypes.push_back(operand.getType());
290299
}
291-
rewriter.setInsertionPoint(subgroupOp);
300+
rewriter.setInsertionPoint(warpOp);
292301
gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
293-
rewriter, subgroupOp, /* new yielded values = */ newYieldValues,
302+
rewriter, warpOp, /* new yielded values = */ newYieldValues,
294303
/* new yielded types = */ newYieldTypes, newRetIndices);
295304

296305
SmallVector<Value> newDescOperands;
@@ -347,10 +356,10 @@ struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
347356
/// ```
348357
struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
349358
using gpu::WarpDistributionPattern::WarpDistributionPattern;
350-
LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
359+
LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
351360
PatternRewriter &rewriter) const override {
352361
auto yield = cast<gpu::YieldOp>(
353-
subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
362+
warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
354363
Operation *lastNode = yield->getPrevNode();
355364
auto storeOp = dyn_cast_or_null<xegpu::StoreNdOp>(lastNode);
356365
if (!storeOp)
@@ -372,7 +381,7 @@ struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
372381

373382
SmallVector<size_t> newRetIndices;
374383
gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
375-
rewriter, subgroupOp,
384+
rewriter, warpOp,
376385
/* new yielded values = */
377386
ValueRange{storeOp.getValue(), storeOp.getTensorDesc()},
378387
/* new yielded types = */
@@ -449,21 +458,22 @@ struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
449458
/// ```
450459
struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
451460
using gpu::WarpDistributionPattern::WarpDistributionPattern;
452-
LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
461+
LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
453462
PatternRewriter &rewriter) const override {
454-
OpOperand *operand =
455-
getWarpResult(subgroupOp, llvm::IsaPred<xegpu::LoadNdOp>);
463+
OpOperand *operand = getWarpResult(warpOp, [&](Operation *op) {
464+
if (!isa<xegpu::LoadNdOp>(op))
465+
return false;
466+
// Make sure the same load op is the last operation in the warp op body.
467+
// This ensures that the load op is not sunk earlier, violating any barrier
468+
// synchronization.
469+
auto yield = cast<gpu::YieldOp>(
470+
warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
471+
return yield->getPrevNode() == op;
472+
});
473+
456474
if (!operand)
457475
return rewriter.notifyMatchFailure(
458-
subgroupOp, "warp result is not a xegpu::LoadNd op");
459-
// Make sure the load op is the last operation in the warp op body. This
460-
// ensure that load op is not sinked earlier violating any barrier
461-
// synchronizations.
462-
auto yield = cast<gpu::YieldOp>(
463-
subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
464-
Operation *lastNode = yield->getPrevNode();
465-
if (!dyn_cast_or_null<xegpu::LoadNdOp>(lastNode))
466-
return failure();
476+
warpOp, "warp result is not a xegpu::LoadNd op");
467477

468478
auto loadOp = operand->get().getDefiningOp<xegpu::LoadNdOp>();
469479
xegpu::TensorDescType tensorDescTy = loadOp.getTensorDescType();
@@ -474,11 +484,11 @@ struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
474484

475485
unsigned operandIdx = operand->getOperandNumber();
476486
VectorType distributedTypeByWarpOp =
477-
cast<VectorType>(subgroupOp.getResult(operandIdx).getType());
487+
cast<VectorType>(warpOp.getResult(operandIdx).getType());
478488

479489
SmallVector<size_t> newRetIndices;
480490
gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
481-
rewriter, subgroupOp,
491+
rewriter, warpOp,
482492
/* new yielded values = */ loadOp.getTensorDesc(),
483493
/* new yielded types = */ tensorDescTy, newRetIndices);
484494

@@ -548,12 +558,11 @@ struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
548558
/// ```
549559
struct DpasDistribution final : public gpu::WarpDistributionPattern {
550560
using gpu::WarpDistributionPattern::WarpDistributionPattern;
551-
LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
561+
LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
552562
PatternRewriter &rewriter) const override {
553-
OpOperand *operand =
554-
getWarpResult(subgroupOp, llvm::IsaPred<xegpu::DpasOp>);
563+
OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<xegpu::DpasOp>);
555564
if (!operand)
556-
return rewriter.notifyMatchFailure(subgroupOp,
565+
return rewriter.notifyMatchFailure(warpOp,
557566
"warp result is not a xegpu::Dpas op");
558567

559568
auto dpasOp = operand->get().getDefiningOp<xegpu::DpasOp>();
@@ -599,7 +608,7 @@ struct DpasDistribution final : public gpu::WarpDistributionPattern {
599608
// Create a new warp op without the dpas.
600609
SmallVector<size_t> newRetIndices;
601610
gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
602-
rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);
611+
rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);
603612

604613
FailureOr<VectorType> expectedDistLhsTyOrFailure =
605614
xegpu::getDistributedVectorType(dpasOp.getLhsType(), layoutA);
@@ -678,13 +687,13 @@ struct DpasDistribution final : public gpu::WarpDistributionPattern {
678687
/// ```
679688
struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
680689
using gpu::WarpDistributionPattern::WarpDistributionPattern;
681-
LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
690+
LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
682691
PatternRewriter &rewriter) const override {
683692
OpOperand *operand =
684-
getWarpResult(subgroupOp, llvm::IsaPred<xegpu::UpdateNdOffsetOp>);
693+
getWarpResult(warpOp, llvm::IsaPred<xegpu::UpdateNdOffsetOp>);
685694
if (!operand)
686695
return rewriter.notifyMatchFailure(
687-
subgroupOp, "warp result is not a xegpu::UpdateNdOffset op");
696+
warpOp, "warp result is not a xegpu::UpdateNdOffset op");
688697
auto updateOp = operand->get().getDefiningOp<xegpu::UpdateNdOffsetOp>();
689698
unsigned operandIdx = operand->getOperandNumber();
690699
// new update op does not have layout attribute.
@@ -703,7 +712,7 @@ struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
703712
}
704713
SmallVector<size_t> newRetIndices;
705714
gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
706-
rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);
715+
rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);
707716
rewriter.setInsertionPointAfter(newWarpOp);
708717
SmallVector<Value> newUpdateOperands;
709718
for (size_t i : newRetIndices) {
@@ -758,10 +767,10 @@ struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
758767
/// ```
759768
struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
760769
using gpu::WarpDistributionPattern::WarpDistributionPattern;
761-
LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
770+
LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
762771
PatternRewriter &rewriter) const override {
763772
auto yield = cast<gpu::YieldOp>(
764-
subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
773+
warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
765774
Operation *lastNode = yield->getPrevNode();
766775
auto prefetchOp = dyn_cast_or_null<xegpu::PrefetchNdOp>(lastNode);
767776
if (!prefetchOp)
@@ -775,7 +784,7 @@ struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
775784
SmallVector<Type, 1> newYieldTypes = {prefetchOp.getTensorDescType()};
776785
SmallVector<size_t> newRetIndices;
777786
gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
778-
rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);
787+
rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);
779788
// Create a new prefetch op outside the warp op with updated tensor
780789
// descriptor type. Source tensor descriptor require type resolution.
781790
xegpu::TensorDescType newTensorDescTy =
@@ -795,17 +804,17 @@ struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
795804
/// region. This will simply move the barrier op outside of the warp op.
796805
struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
797806
using gpu::WarpDistributionPattern::WarpDistributionPattern;
798-
LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
807+
LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
799808
PatternRewriter &rewriter) const override {
800809
auto yield = cast<gpu::YieldOp>(
801-
subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
810+
warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
802811
Operation *lastNode = yield->getPrevNode();
803812
// The last node must be a gpu::BarrierOp.
804813
auto barrierOp = dyn_cast_or_null<gpu::BarrierOp>(lastNode);
805814
if (!barrierOp)
806815
return failure();
807816
// Move the barrier op outside of the warp op.
808-
rewriter.setInsertionPointAfter(subgroupOp);
817+
rewriter.setInsertionPointAfter(warpOp);
809818
rewriter.create<gpu::BarrierOp>(
810819
barrierOp.getLoc(), barrierOp->getResultTypes(),
811820
barrierOp->getOperands(), barrierOp->getAttrs());

mlir/test/Dialect/XeGPU/subgroup-distribute.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -95,10 +95,10 @@ gpu.module @test {
9595
// -----
9696
// CHECK-LABEL: gpu.func @load_dpas_store
9797
// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16x16xf16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) {
98-
// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%{{.*}}] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
99-
// CHECK: %[[T1:.*]] = xegpu.load_nd %[[T0]] <{packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16>
10098
// CHECK: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
10199
// CHECK: %[[T3:.*]] = xegpu.load_nd %[[T2]] : !xegpu.tensor_desc<8x16xf16> -> vector<8xf16>
100+
// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%{{.*}}] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
101+
// CHECK: %[[T1:.*]] = xegpu.load_nd %[[T0]] <{packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16>
102102
// CHECK-DAG: %[[T4:.*]] = xegpu.dpas %[[T3]], %[[T1]] : vector<8xf16>, vector<16xf16> -> vector<8xf32>
103103
// CHECK-DAG: %[[T5:.*]] = xegpu.create_nd_tdesc %[[ARG2]][%{{.*}}] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
104104
// CHECK: xegpu.store_nd %[[T4]], %[[T5]] : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
@@ -120,10 +120,10 @@ gpu.module @test {
120120
// -----
121121
// CHECK-LABEL: gpu.func @load_dpas_postop_store
122122
// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16x16xf16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) {
123-
// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%{{.*}}] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
124-
// CHECK: %[[T1:.*]] = xegpu.load_nd %[[T0]] <{packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16>
125123
// CHECK: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
126124
// CHECK: %[[T3:.*]] = xegpu.load_nd %[[T2]] : !xegpu.tensor_desc<8x16xf16> -> vector<8xf16>
125+
// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%{{.*}}] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
126+
// CHECK: %[[T1:.*]] = xegpu.load_nd %[[T0]] <{packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16>
127127
// CHECK-DAG: %[[T4:.*]] = xegpu.dpas %[[T3]], %[[T1]] : vector<8xf16>, vector<16xf16> -> vector<8xf32>
128128
// CHECK: %[[T5:.*]] = vector.shape_cast %[[T4]] : vector<8xf32> to vector<8x1xf32>
129129
// CHECK: %[[T6:.*]] = math.exp %[[T5]] {{{.*}}} : vector<8x1xf32>

0 commit comments

Comments
 (0)