Skip to content

Commit 30b099e

Browse files
committed
add create_desc unrolling and test
1 parent bac8bc6 commit 30b099e

File tree

2 files changed

+64
-1
lines changed

2 files changed

+64
-1
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -396,11 +396,50 @@ struct UnrollDpasOp : public UnrollPattern<xegpu::DpasOp> {
396396
}
397397
};
398398

399+
struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
400+
using UnrollPattern<xegpu::CreateDescOp>::UnrollPattern;
401+
LogicalResult matchAndRewrite(xegpu::CreateDescOp op,
402+
PatternRewriter &rewriter) const override {
403+
Location loc = op.getLoc();
404+
xegpu::TensorDescType tdescTy = op.getType();
405+
406+
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
407+
if (!targetShape)
408+
return failure();
409+
410+
auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
411+
412+
413+
TypedValue<::mlir::VectorType> indiceVec = op.getOffsets();
414+
415+
VectorType indiceVecTy = indiceVec.getType();
416+
417+
SmallVector<Type> convertedIndiceTypes =
418+
getUnrolledTypes(indiceVecTy, *targetShape);
419+
420+
SmallVector<Value> convertedIndiceVec =
421+
pack(indiceVec, convertedIndiceTypes, *targetShape, loc, rewriter);
422+
423+
SmallVector<Value> newOps;
424+
425+
for (auto indice : convertedIndiceVec) {
426+
auto newOp = rewriter.create<xegpu::CreateDescOp>(loc, newTdescTy, op.getSource(), indice);
427+
newOps.push_back(newOp);
428+
}
429+
430+
Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
431+
rewriter.replaceOp(op, castOp);
432+
433+
return success();
434+
}
435+
};
436+
399437
} // namespace
400438

401439
void mlir::xegpu::populateXeGPUUnrollPatterns(
402440
RewritePatternSet &patterns, const xegpu::UnrollOptions &options) {
403441
patterns.add<UnrollCreateNdOp, UnrollUpdateNdOffsetOp, UnrollPrefetchNdOp,
404-
UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp>(
442+
UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp,
443+
UnrollCreateDescOp>(
405444
patterns.getContext(), options);
406445
}

mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,30 @@ struct TestXeGPUUnrollingPatterns
7171
}
7272
}
7373

74+
if (isa<xegpu::CreateDescOp, xegpu::UpdateOffsetOp,
75+
xegpu::PrefetchOp, xegpu::LoadGatherOp, xegpu::StoreScatterOp>(op)) {
76+
xegpu::TensorDescType tdescTy;
77+
if (auto createOp = dyn_cast<xegpu::CreateDescOp>(op)) {
78+
tdescTy = createOp.getType();
79+
} else if (auto updateOp =
80+
dyn_cast<xegpu::UpdateOffsetOp>(op)) {
81+
tdescTy = updateOp.getTensorDescType();
82+
} else if (auto prefetchOp = dyn_cast<xegpu::PrefetchOp>(op)) {
83+
tdescTy = prefetchOp.getTensorDescType();
84+
} else if (auto loadOp = dyn_cast<xegpu::LoadGatherOp>(op)) {
85+
tdescTy = loadOp.getTensorDescType();
86+
} else if (auto storeOp = dyn_cast<xegpu::StoreScatterOp>(op)) {
87+
tdescTy = storeOp.getTensorDescType();
88+
}
89+
90+
if (auto layout = tdescTy.getLayoutAttr()) {
91+
auto inst_data = layout.getInstData();
92+
if (inst_data && layout.isSgLayout())
93+
return SmallVector<int64_t>(inst_data.asArrayRef().begin(),
94+
inst_data.asArrayRef().end());
95+
}
96+
}
97+
7498
if (isa<xegpu::DpasOp>(op))
7599
return SmallVector<int64_t>{8, 16, 16};
76100

0 commit comments

Comments
 (0)