Skip to content

Commit 6447c63

Browse files
committed
add prefetch support
1 parent 08d9e7b commit 6447c63

File tree

2 files changed

+83
-4
lines changed

2 files changed

+83
-4
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp

Lines changed: 54 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
#include "mlir/IR/Value.h"
3030
#include "mlir/IR/Visitors.h"
3131
#include "mlir/Interfaces/FunctionInterfaces.h"
32-
#include "mlir/Interfaces/InferTypeOpInterface.h"
3332
#include "mlir/Support/LLVM.h"
3433
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
3534
#include "mlir/Transforms/InliningUtils.h"
@@ -300,6 +299,9 @@ class LayoutInfoPropagation
300299
void visitUpdateNdOffsetOp(xegpu::UpdateNdOffsetOp updateNdOffset,
301300
ArrayRef<LayoutInfoLattice *> operands,
302301
ArrayRef<const LayoutInfoLattice *> results);
302+
void visitPrefetchNdOp(xegpu::PrefetchNdOp prefetch,
303+
ArrayRef<LayoutInfoLattice *> operands,
304+
ArrayRef<const LayoutInfoLattice *> results);
303305

304306
void visitVectorMultiReductionOp(vector::MultiDimReductionOp reduction,
305307
ArrayRef<LayoutInfoLattice *> operands,
@@ -352,6 +354,9 @@ LogicalResult LayoutInfoPropagation::visitOperation(
352354
.Case<xegpu::UpdateNdOffsetOp>([&](auto updateNdOffsetOp) {
353355
visitUpdateNdOffsetOp(updateNdOffsetOp, operands, results);
354356
})
357+
.Case<xegpu::PrefetchNdOp>([&](auto prefetchNdOp) {
358+
visitPrefetchNdOp(prefetchNdOp, operands, results);
359+
})
355360
/// No need to propagate the layout to operands in CreateNdDescOp because
356361
/// they are scalars (offsets, sizes, etc.).
357362
.Case<xegpu::CreateNdDescOp>([&](auto createNdDescOp) {})
@@ -400,6 +405,18 @@ void LayoutInfoPropagation::visitVectorMultiReductionOp(
400405
propagateIfChanged(operands[1], operands[1]->meet(resultLayout));
401406
}
402407

408+
void LayoutInfoPropagation::visitPrefetchNdOp(
409+
xegpu::PrefetchNdOp prefetch, ArrayRef<LayoutInfoLattice *> operands,
410+
ArrayRef<const LayoutInfoLattice *> results) {
411+
/// Here we assign the default layout to the tensor descriptor operand of
412+
/// prefetch.
413+
auto tdescTy = prefetch.getTensorDescType();
414+
auto prefetchLayout = getDefaultLayoutInfo(
415+
VectorType::get(tdescTy.getShape(), tdescTy.getElementType()));
416+
/// Propagate the layout to the source tensor descriptor.
417+
propagateIfChanged(operands[0], operands[0]->meet(prefetchLayout));
418+
}
419+
403420
/// Propagate the layout of the result tensor to the source tensor descriptor in
404421
/// UpdateNdOffsetOp.
405422
void LayoutInfoPropagation::visitUpdateNdOffsetOp(
@@ -1488,6 +1505,39 @@ struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
14881505
}
14891506
};
14901507

1508+
struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
1509+
using gpu::WarpDistributionPattern::WarpDistributionPattern;
1510+
LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
1511+
PatternRewriter &rewriter) const override {
1512+
auto yield = cast<gpu::YieldOp>(
1513+
subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
1514+
Operation *lastNode = yield->getPrevNode();
1515+
auto prefetchOp = dyn_cast_or_null<xegpu::PrefetchNdOp>(lastNode);
1516+
if (!prefetchOp)
1517+
return failure();
1518+
auto layout = prefetchOp.getTensorDescType().getLayoutAttr();
1519+
if (!layout)
1520+
return rewriter.notifyMatchFailure(
1521+
prefetchOp, "the source tensor descriptor lacks layout attribute");
1522+
1523+
SmallVector<Value, 1> newYieldValues = {prefetchOp.getTensorDesc()};
1524+
SmallVector<Type, 1> newYieldTypes = {prefetchOp.getTensorDescType()};
1525+
SmallVector<size_t> newRetIndices;
1526+
gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1527+
rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);
1528+
1529+
auto newTensorDescTy = dropLayouts(prefetchOp.getTensorDescType());
1530+
rewriter.setInsertionPointAfter(newWarpOp);
1531+
SmallVector<Value> newPrefetchOperands = {resolveDistributedTy(
1532+
newWarpOp.getResult(newRetIndices[0]), newTensorDescTy, rewriter)};
1533+
rewriter.create<xegpu::PrefetchNdOp>(
1534+
newWarpOp.getLoc(), TypeRange{}, newPrefetchOperands,
1535+
removeTemporaryLayoutAttributes(prefetchOp->getAttrs()));
1536+
rewriter.eraseOp(prefetchOp);
1537+
return success();
1538+
}
1539+
};
1540+
14911541
/// Generic pattern for sinking a GPU index operations feeding into yield op
14921542
/// of an enclosing `gpu.warp_execute_on_lane_0` region. The original index op
14931543
/// becomes dead and an equivalent copy of the index op is created outside the
@@ -1562,9 +1612,9 @@ struct XeGPUSubgroupDistributePass final
15621612

15631613
void xegpu::populateXeGPUSubgroupDistributePatterns(
15641614
RewritePatternSet &patterns) {
1565-
patterns
1566-
.add<CreateNdDescDistribution, StoreNdDistribution, LoadNdDistribution,
1567-
DpasDistribution, UpdateNdOffsetDistribution>(patterns.getContext());
1615+
patterns.add<CreateNdDescDistribution, StoreNdDistribution,
1616+
LoadNdDistribution, DpasDistribution, UpdateNdOffsetDistribution,
1617+
PrefetchNdDistribution>(patterns.getContext());
15681618
/// TODO: Is this the right place to add these patterns?
15691619
patterns.add<GpuIndexOpDistribution<gpu::BlockIdOp>,
15701620
GpuIndexOpDistribution<gpu::BlockDimOp>,

mlir/test/Dialect/XeGPU/subgroup-distribution.mlir

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,35 @@ gpu.func @test_uddate_nd_offset_2d(%arg0: memref<256x256xf32>){
178178
}
179179
}
180180

181+
// -----
182+
// CHECK-LABEL: gpu.func @test_prefetch_2d
183+
// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<256x256xf16>) {
184+
// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16>
185+
// CHECK: xegpu.prefetch_nd %[[T0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16x16xf16>
186+
gpu.module @test {
187+
gpu.func @test_prefetch_2d(%arg0: memref<256x256xf16>){
188+
%c0 = arith.constant 0 : index
189+
%0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16>
190+
xegpu.prefetch_nd %0 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: !xegpu.tensor_desc<16x16xf16>
191+
gpu.return
192+
}
193+
}
194+
195+
// -----
196+
// CHECK-LABEL: gpu.func @test_prefetch_1d
197+
// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<256xf16>) {
198+
// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<256xf16> -> !xegpu.tensor_desc<16xf16>
199+
// CHECK: xegpu.prefetch_nd %[[T0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16xf16>
200+
gpu.module @test {
201+
gpu.func @test_prefetch_1d(%arg0: memref<256xf16>){
202+
%c0 = arith.constant 0 : index
203+
%0 = xegpu.create_nd_tdesc %arg0[%c0] : memref<256xf16> -> !xegpu.tensor_desc<16xf16>
204+
xegpu.prefetch_nd %0 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: !xegpu.tensor_desc<16xf16>
205+
gpu.return
206+
}
207+
}
208+
209+
181210
// -----
182211
// CHECK-LABEL: gpu.func @test_gemm_loop
183212
// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<1024x1024xbf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<1024x1024xbf16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<1024x1024xf32>) {

0 commit comments

Comments (0)