Skip to content

Commit b3bf12f

Browse files
committed
Add prefetch_nd op
1 parent 1ed4cb5 commit b3bf12f

File tree

3 files changed

+60
-46
lines changed

3 files changed

+60
-46
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp

Lines changed: 32 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,10 @@ using namespace mlir;
3434
namespace {
3535

3636
// clang-format off
37-
/// This pattern transform the CreateNdDescOp to create a subgroup descriptor
37+
/// This pattern transforms the CreateNdDescOp to create a subgroup descriptor
3838
/// from a workgroup descriptor. It replaces the offsets and sizes with
3939
/// appropriate values for the subgroup.
40-
/// It uses round-robin distribution to create the subgroup descriptor.
41-
40+
/// It uses round-robin assignment to distribute the work to the subgroups.
4241
/// Following create_nd_desc operation:,
4342
/// %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x24xf32>
4443
/// -> !xegpu.tensor_desc<24x24xf32, #xegpu.layout<sg_layout = [4, 4],
@@ -47,7 +46,7 @@ namespace {
4746
/// %tdesc = xegpu.create_nd_tdesc %src[off1, off2] : memref<24x24xf32> ->
4847
/// !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
4948
///
50-
/// The sg_layout and sg_data are dropped from the layout attribute as they are no longer needed.
49+
/// The sg_layout and sg_data attributes are dropped after the pass as they are no longer needed.
5150
///
5251
/// 24x24 matrix distribution example:
5352
/// sg_layout = [4, 4], sg_data = [2, 2]
@@ -72,7 +71,6 @@ namespace {
7271
///
7372
/// Since the 24x24 matrix is divided into 8x8 distribution units, there will be 9
7473
/// distribution units (3x3) in total. Hence the 9 subgroup level operations.
75-
/// Each 8x8 matrix within the 24x24 matrix is called a distribution unit.
7674
// clang-format on
7775
struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
7876
using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
@@ -110,7 +108,7 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
110108
return rewriter.create<arith::ConstantIndexOp>(loc, value);
111109
}
112110

113-
// Calculate global offset for each subgroup
111+
// Calculate offset for each subgroup
114112
SmallVector<OpFoldResult>
115113
calculateGlobalOffsets(ConversionPatternRewriter &rewriter, Location loc,
116114
const SmallVector<Value> &originalOffsets,
@@ -122,13 +120,11 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
122120
Value constOffsetY =
123121
createConstantIndex(rewriter, loc, distUnitBaseAddr[1]);
124122

125-
// Compute offsets within entire tile
126123
Value offsetX =
127124
rewriter.createOrFold<index::AddOp>(loc, localOffset[0], constOffsetX);
128125
Value offsetY =
129126
rewriter.createOrFold<index::AddOp>(loc, localOffset[1], constOffsetY);
130127

131-
// Add to global offsets
132128
size_t lastDimIndex = originalOffsets.size() - 1;
133129
size_t secondLastDimIndex = lastDimIndex - 1;
134130

@@ -137,7 +133,6 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
137133
Value globalOffsetY = rewriter.createOrFold<index::AddOp>(
138134
loc, originalOffsets[lastDimIndex], offsetY);
139135

140-
// Create final offset list
141136
SmallVector<OpFoldResult> globalOffsets(originalOffsets.begin(),
142137
originalOffsets.end());
143138
globalOffsets[secondLastDimIndex] = globalOffsetX;
@@ -172,7 +167,7 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
172167
sgDataDim[i] = createConstantIndex(rewriter, loc, sgShape[i]);
173168
}
174169

175-
// Delinearize the 1D subgroup id into nd coordinates
170+
// Delinearize the 1D subgroup id into 2d
176171
SmallVector<Value> sgIds = delinearizeSubgroupId(
177172
rewriter, loc, linearSgId, sgLayoutDim[0], sgLayoutDim[1]);
178173

@@ -207,8 +202,7 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
207202
}
208203
};
209204

210-
/// This pattern transforms the LoadNdOp to load from a subgroup descriptor
211-
/// It creates a LoadNdOp op to load the new subgroup src tensor descriptors.
205+
/// This pattern transforms the LoadNdOp to load subgroup data.
212206
struct WgToSgLoadNdOp : public OpConversionPattern<xegpu::LoadNdOp> {
213207
using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
214208
LogicalResult
@@ -310,7 +304,22 @@ struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
310304
}
311305
}
312306
rewriter.replaceOpWithMultiple(op, {newDpasOps});
313-
return mlir::success();
307+
return success();
308+
}
309+
};
310+
311+
/// This pattern transforms the PrefetchNdOp to prefetch the subgroup data.
312+
struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
313+
using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
314+
LogicalResult
315+
matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
316+
ConversionPatternRewriter &rewriter) const override {
317+
for (auto src : adaptor.getTensorDesc()) {
318+
rewriter.create<xegpu::PrefetchNdOp>(op.getLoc(), TypeRange(), src,
319+
op->getAttrs());
320+
}
321+
rewriter.eraseOp(op);
322+
return success();
314323
}
315324
};
316325

@@ -320,7 +329,8 @@ namespace mlir {
320329
namespace xegpu {
321330
void populateXeGPUWgToSgPatterns(RewritePatternSet &patterns) {
322331
patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
323-
WgToSgUpdateNdOffsetOp, WgToSgDpasOp>(patterns.getContext());
332+
WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp>(
333+
patterns.getContext());
324334
}
325335
} // namespace xegpu
326336
} // namespace mlir
@@ -345,6 +355,8 @@ void XeGPUWgToSgPass::runOnOperation() {
345355
return storeOp.getTensorDescType();
346356
if (auto updateOp = dyn_cast<xegpu::UpdateNdOffsetOp>(op))
347357
return updateOp.getType();
358+
if (auto prefetchOp = dyn_cast<xegpu::PrefetchNdOp>(op))
359+
return prefetchOp.getTensorDescType();
348360
return xegpu::TensorDescType();
349361
};
350362

@@ -353,12 +365,12 @@ void XeGPUWgToSgPass::runOnOperation() {
353365
};
354366

355367
target.addDynamicallyLegalOp<xegpu::CreateNdDescOp, xegpu::LoadNdOp,
356-
xegpu::StoreNdOp, xegpu::UpdateNdOffsetOp>(
357-
[=](Operation *op) -> bool {
358-
auto tdescTy = getTensorDescType(op);
359-
auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(tdescTy.getLayout());
360-
return isLegal(layout);
361-
});
368+
xegpu::StoreNdOp, xegpu::UpdateNdOffsetOp,
369+
xegpu::PrefetchNdOp>([=](Operation *op) -> bool {
370+
auto tdescTy = getTensorDescType(op);
371+
auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(tdescTy.getLayout());
372+
return isLegal(layout);
373+
});
362374

363375
target.addDynamicallyLegalOp<xegpu::DpasOp>([=](xegpu::DpasOp op) -> bool {
364376
auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(op->getAttr("layout"));

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -42,18 +42,10 @@ gpu.module @test_round_robin_assignment {
4242
// CHECK: %[[ARG_1:.*]]: memref<32x24xf32>
4343
// CHECK: %[[ARG_2:.*]]: memref<24x24xf32>
4444
gpu.func @test_dpas(%a: memref<24x32xf32>, %b: memref<32x24xf32>, %c: memref<24x24xf32>) {
45-
// CHECK-COUNT-12: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}},
46-
// %{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<2x2xf32,
47-
// #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>> CHECK-COUNT-12:
48-
// %[[TDESC1:.*]] = xegpu.create_nd_tdesc %[[ARG_1]][%{{.*}}, %{{.*}}] :
49-
// memref<32x24xf32> -> !xegpu.tensor_desc<2x2xf32,
50-
// #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>> CHECK-COUNT-9:
51-
// %[[TDESC2:.*]] = xegpu.create_nd_tdesc %{{.*}}[%{{.*}}, %{{.*}}] :
52-
// memref<24x24xf32> -> !xegpu.tensor_desc<2x2xf32,
53-
// #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>> CHECK-COUNT-144:
54-
// %[[DPAS:.*]] = xegpu.dpas %{{.*}}, %{{.*}} {layout =
55-
// #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} :
56-
// vector<2x2xf32>, vector<2x2xf32> -> vector<2x2xf32>
45+
// CHECK-COUNT-12: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
46+
// CHECK-COUNT-12: %[[TDESC1:.*]] = xegpu.create_nd_tdesc %[[ARG_1]][%{{.*}}, %{{.*}}] : memref<32x24xf32> -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
47+
// CHECK-COUNT-9: %[[TDESC2:.*]] = xegpu.create_nd_tdesc %{{.*}}[%{{.*}}, %{{.*}}] : memref<24x24xf32> -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
48+
// CHECK-COUNT-144: %[[DPAS:.*]] = xegpu.dpas %{{.*}}, %{{.*}} {layout = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>, vector<2x2xf32> -> vector<2x2xf32>
5749
%tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
5850
%load_a = xegpu.load_nd %tdesc_a: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> -> vector<24x32xf32>
5951
%tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<32x24xf32> -> !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
@@ -62,4 +54,13 @@ gpu.module @test_round_robin_assignment {
6254
%dpas = xegpu.dpas %load_a, %load_b {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>} : vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32>
6355
gpu.return
6456
}
57+
58+
// CHECK: test_prefetch_nd_tdesc
59+
// CHECK: %[[ARG_0:.*]]: memref<24x32xf32>
60+
gpu.func @test_prefetch_nd_tdesc(%src: memref<24x32xf32>) {
61+
// CHECK-COUNT-12: xegpu.prefetch_nd %{{.*}} : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
62+
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
63+
xegpu.prefetch_nd %tdesc: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
64+
gpu.return
65+
}
6566
}

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -57,25 +57,26 @@ gpu.func @test_update_nd(%src: memref<24x32xf32>){
5757
// CHECK: %[[ARG_0:.*]]: memref<24x32xf32>
5858
// CHECK: %[[ARG_1:.*]]: memref<32x24xf32>
5959
gpu.func @test_dpas(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
60-
// CHECK: %[[TDESC_A:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}},
61-
// {{%.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<12x8xf32,
62-
// #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> CHECK:
63-
// %[[LOAD_A:.*]] = xegpu.load_nd %[[TDESC_A]] :
64-
// !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8],
65-
// lane_data = [1, 1]>> -> vector<12x8xf32> CHECK: %[[TDESC_B:.*]] =
66-
// xegpu.create_nd_tdesc %[[ARG_1]][{{%.*}}, {{%.*}}] : memref<32x24xf32> ->
67-
// !xegpu.tensor_desc<8x12xf32, #xegpu.layout<lane_layout = [8, 2],
68-
// lane_data = [1, 1]>> CHECK: %[[LOAD_B:.*]] = xegpu.load_nd %[[TDESC_B]] :
69-
// !xegpu.tensor_desc<8x12xf32, #xegpu.layout<lane_layout = [8, 2],
70-
// lane_data = [1, 1]>> -> vector<8x12xf32> CHECK: %[[DPAS:.*]] = xegpu.dpas
71-
// %[[LOAD_A]], %[[LOAD_B]] {layout = #xegpu.layout<lane_layout = [2, 2],
72-
// lane_data = [1, 1]>} : vector<12x8xf32>, vector<8x12xf32> ->
73-
// vector<12x12xf32>
60+
// CHECK: %[[TDESC_A:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
61+
// CHECK: %[[LOAD_A:.*]] = xegpu.load_nd %[[TDESC_A]] : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> -> vector<12x8xf32>
62+
// CHECK: %[[TDESC_B:.*]] = xegpu.create_nd_tdesc %[[ARG_1]][{{%.*}}, {{%.*}}] : memref<32x24xf32> -> !xegpu.tensor_desc<8x12xf32, #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>>
63+
// CHECK: %[[LOAD_B:.*]] = xegpu.load_nd %[[TDESC_B]] : !xegpu.tensor_desc<8x12xf32, #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>> -> vector<8x12xf32>
64+
// CHECK: %[[DPAS:.*]] = xegpu.dpas %[[LOAD_A]], %[[LOAD_B]] {layout = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<12x8xf32>, vector<8x12xf32> -> vector<12x12xf32>
7465
%tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
7566
%load_a = xegpu.load_nd %tdesc_a: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>> -> vector<24x32xf32>
7667
%tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<32x24xf32> -> !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 2], sg_data = [8, 12], lane_layout = [8, 2], lane_data = [1, 1]>>
7768
%load_b = xegpu.load_nd %tdesc_b: !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 2], sg_data = [8, 12], lane_layout = [8, 2], lane_data = [1, 1]>> -> vector<32x24xf32>
7869
%dpas = xegpu.dpas %load_a, %load_b {layout = #xegpu.layout<sg_layout = [2, 2], sg_data = [12, 12], lane_layout = [2, 2], lane_data = [1, 1]>} : vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32>
7970
gpu.return
8071
}
72+
73+
// CHECK: test_prefetch_nd_tdesc
74+
// CHECK: %[[ARG_0:.*]]: memref<24x32xf32>
75+
gpu.func @test_prefetch_nd_tdesc(%src: memref<24x32xf32>) {
76+
// CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
77+
// CHECK: xegpu.prefetch_nd %[[TDESC]] : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
78+
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
79+
xegpu.prefetch_nd %tdesc: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
80+
gpu.return
81+
}
8182
}

0 commit comments

Comments
 (0)