Skip to content

Commit 9d24920

Browse files
committed
Clean up and add unroll patterns for the remaining nd ops
1 parent 6fef430 commit 9d24920

File tree

3 files changed

+105
-76
lines changed

3 files changed

+105
-76
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp

Lines changed: 76 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -73,11 +73,6 @@ struct UnrollPattern : public OpRewritePattern<SourceOp> {
7373
std::optional<SmallVector<Type>>
7474
convertType(ShapedType type, llvm::ArrayRef<int64_t> blockSize) const {
7575
auto elemTy = type.getElementType();
76-
auto maybeGrids = computeGrids(type.getShape(), blockSize);
77-
78-
if (!maybeGrids)
79-
return std::nullopt;
80-
8176
Type newTy;
8277
// TensorDescType needs to drop the inst_data field in the layout attribute
8378
if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(type)) {
@@ -90,7 +85,9 @@ struct UnrollPattern : public OpRewritePattern<SourceOp> {
9085
newTy = type.clone(blockSize, elemTy);
9186
}
9287

93-
return llvm::SmallVector<Type>(computeProduct(*maybeGrids), newTy);
88+
auto ratio = computeShapeRatio(type.getShape(), blockSize);
89+
assert(ratio && "Expecting the ratio to be valid.");
90+
return llvm::SmallVector<Type>(computeProduct(*ratio), newTy);
9491
}
9592

9693
// emulate the unpack behavior using insert_strided_slice for VectorType
@@ -114,16 +111,15 @@ struct UnrollPattern : public OpRewritePattern<SourceOp> {
114111
}
115112
}
116113
return result;
114+
}
117115

118-
} else if (isa<xegpu::TensorDescType>(destTy)) {
116+
if (isa<xegpu::TensorDescType>(destTy)) {
119117
auto attr = NamedAttribute(rewriter.getStringAttr(unpackAttrName),
120118
rewriter.getUnitAttr());
121-
auto innerBlkAttr =
122-
NamedAttribute(rewriter.getStringAttr(blockAttrName),
123-
rewriter.getDenseI64ArrayAttr(blockSize));
119+
auto blkAttr = NamedAttribute(rewriter.getStringAttr(blockAttrName),
120+
rewriter.getDenseI64ArrayAttr(blockSize));
124121
auto castOp = rewriter.create<UnrealizedConversionCastOp>(
125-
loc, destTy, srcs,
126-
llvm::ArrayRef<NamedAttribute>({attr, innerBlkAttr}));
122+
loc, destTy, srcs, llvm::ArrayRef<NamedAttribute>({attr, blkAttr}));
127123
return castOp.getResult(0);
128124
}
129125

@@ -150,15 +146,15 @@ struct UnrollPattern : public OpRewritePattern<SourceOp> {
150146
}
151147
}
152148
return results;
153-
} else if (isa<xegpu::TensorDescType>(src.getType())) {
149+
}
150+
151+
if (isa<xegpu::TensorDescType>(src.getType())) {
154152
auto attr = NamedAttribute(rewriter.getStringAttr(packAttrName),
155153
rewriter.getUnitAttr());
156-
auto innerBlkAttr =
157-
NamedAttribute(rewriter.getStringAttr(blockAttrName),
158-
rewriter.getDenseI64ArrayAttr(blockSize));
154+
auto blkAttr = NamedAttribute(rewriter.getStringAttr(blockAttrName),
155+
rewriter.getDenseI64ArrayAttr(blockSize));
159156
auto castOp = rewriter.create<UnrealizedConversionCastOp>(
160-
loc, destTypes, src,
161-
llvm::ArrayRef<NamedAttribute>({attr, innerBlkAttr}));
157+
loc, destTypes, src, llvm::ArrayRef<NamedAttribute>({attr, blkAttr}));
162158
return castOp.getResults();
163159
}
164160

@@ -242,11 +238,70 @@ struct UnrollCreateNdOp : public UnrollPattern<xegpu::CreateNdDescOp> {
242238
}
243239
};
244240

241+
struct UnrollUpdateNdOffsetOp : public UnrollPattern<xegpu::UpdateNdOffsetOp> {
242+
using UnrollPattern<xegpu::UpdateNdOffsetOp>::UnrollPattern;
243+
LogicalResult matchAndRewrite(xegpu::UpdateNdOffsetOp op,
244+
PatternRewriter &rewriter) const override {
245+
auto loc = op.getLoc();
246+
auto tdesc = op.getTensorDesc();
247+
auto tdescTy = tdesc.getType();
248+
auto shape = tdescTy.getShape();
249+
250+
auto maybeTargetShape = getTargetShape(op);
251+
if (!maybeTargetShape)
252+
return failure();
253+
auto targetShape = *maybeTargetShape;
254+
255+
auto maybeGrids = computeGrids(shape, targetShape);
256+
if (!maybeGrids)
257+
return failure();
258+
auto grids = *maybeGrids;
259+
260+
auto convertedTdescTypes = convertType(tdescTy, targetShape);
261+
auto convertedTdesc =
262+
pack(tdesc, *convertedTdescTypes, targetShape, loc, rewriter);
263+
264+
llvm::SmallVector<Value> newOps;
265+
for (auto t : convertedTdesc) {
266+
auto newOp = rewriter.create<xegpu::UpdateNdOffsetOp>(
267+
loc, t.getType(), t, op.getOffsets(), op.getConstOffsets());
268+
newOps.push_back(newOp);
269+
}
270+
auto castOp = unpack(newOps, op.getType(), targetShape, loc, rewriter);
271+
rewriter.replaceOp(op, castOp);
272+
return success();
273+
}
274+
};
275+
245276
struct UnrollPrefetchNdOp : public UnrollPattern<xegpu::PrefetchNdOp> {
246277
using UnrollPattern<xegpu::PrefetchNdOp>::UnrollPattern;
247278
LogicalResult matchAndRewrite(xegpu::PrefetchNdOp op,
248279
PatternRewriter &rewriter) const override {
249-
return failure();
280+
auto loc = op.getLoc();
281+
auto tdesc = op.getTensorDesc();
282+
auto tdescTy = tdesc.getType();
283+
auto shape = tdescTy.getShape();
284+
285+
auto maybeTargetShape = getTargetShape(op);
286+
if (!maybeTargetShape)
287+
return failure();
288+
auto targetShape = *maybeTargetShape;
289+
290+
auto maybeGrids = computeGrids(shape, targetShape);
291+
if (!maybeGrids)
292+
return failure();
293+
auto grids = *maybeGrids;
294+
295+
auto convertedTdescTypes = convertType(tdescTy, targetShape);
296+
auto convertedTdesc =
297+
pack(tdesc, *convertedTdescTypes, targetShape, loc, rewriter);
298+
299+
for (auto t : convertedTdesc) {
300+
rewriter.create<xegpu::PrefetchNdOp>(loc, TypeRange(), t, op->getAttrs());
301+
}
302+
303+
rewriter.eraseOp(op);
304+
return success();
250305
}
251306
};
252307

@@ -333,54 +388,6 @@ struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
333388
}
334389
};
335390

336-
struct UnrollUpdateNdOffsetOp : public UnrollPattern<xegpu::UpdateNdOffsetOp> {
337-
using UnrollPattern<xegpu::UpdateNdOffsetOp>::UnrollPattern;
338-
LogicalResult matchAndRewrite(xegpu::UpdateNdOffsetOp op,
339-
PatternRewriter &rewriter) const override {
340-
return failure();
341-
}
342-
};
343-
344-
struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
345-
using UnrollPattern<xegpu::CreateDescOp>::UnrollPattern;
346-
LogicalResult matchAndRewrite(xegpu::CreateDescOp op,
347-
PatternRewriter &rewriter) const override {
348-
return failure();
349-
}
350-
};
351-
352-
struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
353-
using UnrollPattern<xegpu::PrefetchOp>::UnrollPattern;
354-
LogicalResult matchAndRewrite(xegpu::PrefetchOp op,
355-
PatternRewriter &rewriter) const override {
356-
return failure();
357-
}
358-
};
359-
360-
struct UnrollLoadOp : public UnrollPattern<xegpu::LoadGatherOp> {
361-
using UnrollPattern<xegpu::LoadGatherOp>::UnrollPattern;
362-
LogicalResult matchAndRewrite(xegpu::LoadGatherOp op,
363-
PatternRewriter &rewriter) const override {
364-
return failure();
365-
}
366-
};
367-
368-
struct UnrollStoreOp : public UnrollPattern<xegpu::StoreScatterOp> {
369-
using UnrollPattern<xegpu::StoreScatterOp>::UnrollPattern;
370-
LogicalResult matchAndRewrite(xegpu::StoreScatterOp op,
371-
PatternRewriter &rewriter) const override {
372-
return failure();
373-
}
374-
};
375-
376-
struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
377-
using UnrollPattern<xegpu::UpdateOffsetOp>::UnrollPattern;
378-
LogicalResult matchAndRewrite(xegpu::UpdateOffsetOp op,
379-
PatternRewriter &rewriter) const override {
380-
return failure();
381-
}
382-
};
383-
384391
struct UnrollDpasOp : public UnrollPattern<xegpu::DpasOp> {
385392
using UnrollPattern<xegpu::DpasOp>::UnrollPattern;
386393
LogicalResult matchAndRewrite(xegpu::DpasOp op,
@@ -468,18 +475,12 @@ struct UnrollDpasOp : public UnrollPattern<xegpu::DpasOp> {
468475
}
469476
};
470477

471-
struct UnrollAtomicRMWOp : public UnrollPattern<xegpu::AtomicRMWOp> {
472-
using UnrollPattern<xegpu::AtomicRMWOp>::UnrollPattern;
473-
LogicalResult matchAndRewrite(xegpu::AtomicRMWOp op,
474-
PatternRewriter &rewriter) const override {
475-
return failure();
476-
}
477-
};
478478
} // namespace
479479

480480
void mlir::xegpu::populateXeGPUUnrollPatterns(
481481
RewritePatternSet &patterns,
482482
const mlir::vector::UnrollVectorOptions &options) {
483-
patterns.add<UnrollCreateNdOp, UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp>(
483+
patterns.add<UnrollCreateNdOp, UnrollUpdateNdOffsetOp, UnrollPrefetchNdOp,
484+
UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp>(
484485
patterns.getContext(), options);
485486
}

mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,29 @@ gpu.module @test {
1717

1818
//-----
1919

20+
// CHECK-LABEL: test_update_nd_tdesc
21+
// CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
22+
// CHECK-COUNT-6: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
23+
// CHECK-COUNT-6: [[update:%.+]] = xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<8x16xf32>
24+
gpu.func @test_update_nd_tdesc(%src: memref<24x32xf32>) -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> {
25+
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
26+
%update = xegpu.update_nd_offset %tdesc, [0, 16] : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
27+
gpu.return %update : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
28+
}
29+
30+
//-----
31+
32+
// CHECK-LABEL: test_prefetch_nd_tdesc
33+
// CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
34+
// CHECK-COUNT-6: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
35+
// CHECK-COUNT-6: xegpu.prefetch_nd {{.*}} : !xegpu.tensor_desc<8x16xf32>
36+
gpu.func @test_prefetch_nd_tdesc(%src: memref<24x32xf32>) {
37+
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
38+
xegpu.prefetch_nd %tdesc : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
39+
gpu.return
40+
}
41+
42+
//-----
2043
// CHECK-LABEL: test_load_nd
2144
// CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
2245
// CHECK-COUNT-6: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>

mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,15 @@ struct TestXeGPUUnrollingPatterns
4646
vector::UnrollVectorOptions options;
4747
options.setNativeShapeFn([&](Operation *op)
4848
-> std::optional<SmallVector<int64_t>> {
49-
if (isa<xegpu::CreateNdDescOp, xegpu::LoadNdOp, xegpu::StoreNdOp>(op)) {
49+
if (isa<xegpu::CreateNdDescOp, xegpu::UpdateNdOffsetOp,
50+
xegpu::PrefetchNdOp, xegpu::LoadNdOp, xegpu::StoreNdOp>(op)) {
5051
xegpu::TensorDescType tdescTy;
5152
if (auto createNdOp = dyn_cast<xegpu::CreateNdDescOp>(op)) {
5253
tdescTy = createNdOp.getType();
54+
} else if (auto updateNdOp = dyn_cast<xegpu::UpdateNdOffsetOp>(op)) {
55+
tdescTy = updateNdOp.getTensorDescType();
56+
} else if (auto prefetchNdOp = dyn_cast<xegpu::PrefetchNdOp>(op)) {
57+
tdescTy = prefetchNdOp.getTensorDescType();
5358
} else if (auto loadNdOp = dyn_cast<xegpu::LoadNdOp>(op)) {
5459
tdescTy = loadNdOp.getTensorDescType();
5560
} else if (auto storeNdOp = dyn_cast<xegpu::StoreNdOp>(op)) {

0 commit comments

Comments
 (0)