Skip to content

Commit 248981b

Browse files
committed
add support for create_tdesc with chunk_size
1 parent b3d5937 commit 248981b

File tree

2 files changed

+81
-30
lines changed

2 files changed

+81
-30
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp

Lines changed: 42 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -404,7 +404,7 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
404404
xegpu::TensorDescType tdescTy = op.getType();
405405

406406
// Only 1D and 2D tensor descriptor types are supported.
407-
if (tdescTy.getRank() > 1)
407+
if (tdescTy.getRank() > 2)
408408
return failure();
409409

410410
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
@@ -416,16 +416,47 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
416416
TypedValue<::mlir::VectorType> indiceVec = op.getOffsets();
417417
VectorType indiceVecTy = indiceVec.getType();
418418

419-
SmallVector<Type> convertedIndiceTypes =
420-
getUnrolledTypes(indiceVecTy, *targetShape);
421-
SmallVector<Value> convertedIndiceVec =
422-
pack(indiceVec, convertedIndiceTypes, *targetShape, loc, rewriter);
423-
419+
SmallVector<Type> convertedIndiceTypes;
420+
SmallVector<Value> convertedIndiceVec;
424421
SmallVector<Value> newOps;
425-
for (auto indice : convertedIndiceVec) {
426-
auto newOp = rewriter.create<xegpu::CreateDescOp>(loc, newTdescTy,
422+
423+
if (tdescTy.getRank() == 2) {
424+
SmallVector<int64_t> oneDShape(targetShape->begin(), targetShape->end() - 1);
425+
convertedIndiceTypes = getUnrolledTypes(indiceVecTy, oneDShape);
426+
convertedIndiceVec = pack(indiceVec, convertedIndiceTypes, oneDShape, loc, rewriter);
427+
// Unroll the 2D descriptor: for each 1D slice of indices, emit one
// CreateDescOp per chunk of `innerDim` elements along the innermost dim.
428+
int64_t outerDim = tdescTy.getShape().back();
429+
int64_t innerDim = targetShape->back();
430+
int64_t numInnerLoops = outerDim / innerDim;
431+
432+
// Get element size in bytes
433+
int64_t elemSize = tdescTy.getElementType().getIntOrFloatBitWidth() / 8;
434+
435+
for (auto [indice, indiceType] : llvm::zip(convertedIndiceVec, convertedIndiceTypes)) {
436+
for (int64_t i = 0; i < numInnerLoops; ++i) {
437+
// Shift the base indices by i * innerDim to address the i-th inner chunk.
438+
Value inc = rewriter.create<arith::ConstantIndexOp>(loc, i * innerDim);
439+
Value incVec = rewriter.create<vector::SplatOp>(loc, indiceType, inc);
440+
Value offsetIndice = rewriter.create<arith::AddIOp>(loc, indice, incVec);
441+
442+
auto chunkSizeAttr = rewriter.getI64IntegerAttr(innerDim);
443+
auto newOp = rewriter.create<xegpu::CreateDescOp>(
444+
loc, newTdescTy, op.getSource(), offsetIndice);
445+
446+
newOps.push_back(newOp);
447+
}
448+
}
449+
} else if (tdescTy.getRank() == 1) {
450+
convertedIndiceTypes = getUnrolledTypes(indiceVecTy, *targetShape);
451+
convertedIndiceVec = pack(indiceVec, convertedIndiceTypes, *targetShape, loc, rewriter);
452+
for (auto indice : convertedIndiceVec) {
453+
auto newOp = rewriter.create<xegpu::CreateDescOp>(loc, newTdescTy,
427454
op.getSource(), indice);
428-
newOps.push_back(newOp);
455+
newOps.push_back(newOp);
456+
}
457+
} else {
458+
// Unsupported rank for tensor descriptor
459+
return failure();
429460
}
430461

431462
Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
@@ -445,9 +476,9 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
445476
xegpu::TensorDescType tdescTy = op.getTensorDescType();
446477

447478
// Only 1D and 2D tensor descriptor types are supported.
448-
if (tdescTy.getRank() > 1)
479+
if (tdescTy.getRank() > 2)
449480
return failure();
450-
481+
451482
VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
452483

453484
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);

mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp

Lines changed: 39 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,12 @@ using namespace mlir::xegpu;
1919

2020
namespace {
2121

22+
23+
#define DEBUG_TYPE "test-xegpu-unroll"
24+
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
25+
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
26+
27+
2228
struct TestXeGPUUnrollingPatterns
2329
: public PassWrapper<TestXeGPUUnrollingPatterns,
2430
OperationPass<gpu::GPUModuleOp>> {
@@ -48,33 +54,21 @@ struct TestXeGPUUnrollingPatterns
4854
options.setNativeShapeFn(
4955
[&](Operation *op) -> std::optional<SmallVector<int64_t>> {
5056
if (isa<xegpu::CreateNdDescOp, xegpu::UpdateNdOffsetOp,
51-
xegpu::PrefetchNdOp, xegpu::LoadNdOp, xegpu::StoreNdOp>(op)) {
57+
xegpu::PrefetchNdOp, xegpu::LoadNdOp, xegpu::StoreNdOp,
58+
xegpu::CreateDescOp, xegpu::UpdateOffsetOp, xegpu::PrefetchOp,
59+
xegpu::LoadGatherOp, xegpu::StoreScatterOp>(op)) {
5260
xegpu::TensorDescType tdescTy;
5361
if (auto createNdOp = dyn_cast<xegpu::CreateNdDescOp>(op)) {
5462
tdescTy = createNdOp.getType();
55-
} else if (auto updateNdOp =
56-
dyn_cast<xegpu::UpdateNdOffsetOp>(op)) {
63+
} else if (auto updateNdOp = dyn_cast<xegpu::UpdateNdOffsetOp>(op)) {
5764
tdescTy = updateNdOp.getTensorDescType();
5865
} else if (auto prefetchNdOp = dyn_cast<xegpu::PrefetchNdOp>(op)) {
5966
tdescTy = prefetchNdOp.getTensorDescType();
6067
} else if (auto loadNdOp = dyn_cast<xegpu::LoadNdOp>(op)) {
6168
tdescTy = loadNdOp.getTensorDescType();
6269
} else if (auto storeNdOp = dyn_cast<xegpu::StoreNdOp>(op)) {
6370
tdescTy = storeNdOp.getTensorDescType();
64-
}
65-
66-
if (auto layout = tdescTy.getLayoutAttr()) {
67-
auto inst_data = layout.getInstData();
68-
if (inst_data && layout.isSgLayout())
69-
return SmallVector<int64_t>(inst_data.asArrayRef().begin(),
70-
inst_data.asArrayRef().end());
71-
}
72-
}
73-
74-
if (isa<xegpu::CreateDescOp, xegpu::UpdateOffsetOp, xegpu::PrefetchOp,
75-
xegpu::LoadGatherOp, xegpu::StoreScatterOp>(op)) {
76-
xegpu::TensorDescType tdescTy;
77-
if (auto createOp = dyn_cast<xegpu::CreateDescOp>(op)) {
71+
} else if (auto createOp = dyn_cast<xegpu::CreateDescOp>(op)) {
7872
tdescTy = createOp.getType();
7973
} else if (auto updateOp = dyn_cast<xegpu::UpdateOffsetOp>(op)) {
8074
tdescTy = updateOp.getTensorDescType();
@@ -111,14 +105,40 @@ struct TestXeGPUUnrollingPatterns
111105
Attribute encoding = tdescTy.getEncoding();
112106
auto layout = llvm::dyn_cast_if_present<xegpu::LayoutAttr>(
113107
tdescTy.getLayout());
108+
109+
int64_t newChunkSize = 0;
110+
auto instData = layout.getInstData();
111+
if (!instData.empty())
112+
newChunkSize = instData.asArrayRef().back();
113+
114114
if (layout) {
115115
if (layout.getLaneLayout() == nullptr)
116116
layout = xegpu::LayoutAttr();
117117
else
118118
layout = layout.dropInstData();
119119
}
120-
newTy = xegpu::TensorDescType::get(ctx, tileShape, elemTy, encoding,
121-
layout);
120+
121+
SmallVector<NamedAttribute> attrs;
122+
auto scatterAttr = mlir::dyn_cast<xegpu::ScatterTensorDescAttr>(encoding);
123+
if (scatterAttr) {
124+
int64_t chunkSize = scatterAttr.getChunkSize().getInt();
125+
126+
if (chunkSize > 1) {
127+
128+
auto chunkSizeAttr = mlir::IntegerAttr::get(
129+
mlir::IntegerType::get(ctx, 64), newChunkSize);
130+
131+
// Rebuild the scatter encoding so its chunk_size matches the unrolled
// innermost dimension taken from the layout's inst_data.
132+
auto newEncoding = xegpu::ScatterTensorDescAttr::get(
133+
ctx, scatterAttr.getMemorySpace(), chunkSizeAttr);
134+
135+
encoding = newEncoding;
136+
137+
}
138+
139+
}
140+
newTy = xegpu::TensorDescType::get(ctx, tileShape, elemTy, encoding, layout);
141+
122142
} else {
123143
newTy = type.clone(tileShape, elemTy);
124144
}

0 commit comments

Comments
 (0)