@@ -396,11 +396,50 @@ struct UnrollDpasOp : public UnrollPattern<xegpu::DpasOp> {
396396 }
397397};
398398
399+ struct UnrollCreateDescOp : public UnrollPattern <xegpu::CreateDescOp> {
400+ using UnrollPattern<xegpu::CreateDescOp>::UnrollPattern;
401+ LogicalResult matchAndRewrite (xegpu::CreateDescOp op,
402+ PatternRewriter &rewriter) const override {
403+ Location loc = op.getLoc ();
404+ xegpu::TensorDescType tdescTy = op.getType ();
405+
406+ std::optional<SmallVector<int64_t >> targetShape = getTargetShape (op);
407+ if (!targetShape)
408+ return failure ();
409+
410+ auto newTdescTy = getUnrolledTypes (tdescTy, *targetShape)[0 ];
411+
412+
413+ TypedValue<::mlir::VectorType> indiceVec = op.getOffsets ();
414+
415+ VectorType indiceVecTy = indiceVec.getType ();
416+
417+ SmallVector<Type> convertedIndiceTypes =
418+ getUnrolledTypes (indiceVecTy, *targetShape);
419+
420+ SmallVector<Value> convertedIndiceVec =
421+ pack (indiceVec, convertedIndiceTypes, *targetShape, loc, rewriter);
422+
423+ SmallVector<Value> newOps;
424+
425+ for (auto indice : convertedIndiceVec) {
426+ auto newOp = rewriter.create <xegpu::CreateDescOp>(loc, newTdescTy, op.getSource (), indice);
427+ newOps.push_back (newOp);
428+ }
429+
430+ Value castOp = unpack (newOps, tdescTy, *targetShape, loc, rewriter);
431+ rewriter.replaceOp (op, castOp);
432+
433+ return success ();
434+ }
435+ };
436+
399437} // namespace
400438
401439void mlir::xegpu::populateXeGPUUnrollPatterns (
402440 RewritePatternSet &patterns, const xegpu::UnrollOptions &options) {
403441 patterns.add <UnrollCreateNdOp, UnrollUpdateNdOffsetOp, UnrollPrefetchNdOp,
404- UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp>(
442+ UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp,
443+ UnrollCreateDescOp>(
405444 patterns.getContext (), options);
406445}
0 commit comments