@@ -73,11 +73,6 @@ struct UnrollPattern : public OpRewritePattern<SourceOp> {
7373 std::optional<SmallVector<Type>>
7474 convertType (ShapedType type, llvm::ArrayRef<int64_t > blockSize) const {
7575 auto elemTy = type.getElementType ();
76- auto maybeGrids = computeGrids (type.getShape (), blockSize);
77-
78- if (!maybeGrids)
79- return std::nullopt ;
80-
8176 Type newTy;
8277 // TensorDescType needs to drop the inst_data field in the layout attribute
8378 if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(type)) {
@@ -90,7 +85,9 @@ struct UnrollPattern : public OpRewritePattern<SourceOp> {
9085 newTy = type.clone (blockSize, elemTy);
9186 }
9287
93- return llvm::SmallVector<Type>(computeProduct (*maybeGrids), newTy);
88+ auto ratio = computeShapeRatio (type.getShape (), blockSize);
89+ assert (ratio && " Expecting the ratio to be valid." );
90+ return llvm::SmallVector<Type>(computeProduct (*ratio), newTy);
9491 }
9592
9693 // emulate the the unpack behavior using insert_strided_slice for VectorType
@@ -114,16 +111,15 @@ struct UnrollPattern : public OpRewritePattern<SourceOp> {
114111 }
115112 }
116113 return result;
114+ }
117115
118- } else if (isa<xegpu::TensorDescType>(destTy)) {
116+ if (isa<xegpu::TensorDescType>(destTy)) {
119117 auto attr = NamedAttribute (rewriter.getStringAttr (unpackAttrName),
120118 rewriter.getUnitAttr ());
121- auto innerBlkAttr =
122- NamedAttribute (rewriter.getStringAttr (blockAttrName),
123- rewriter.getDenseI64ArrayAttr (blockSize));
119+ auto blkAttr = NamedAttribute (rewriter.getStringAttr (blockAttrName),
120+ rewriter.getDenseI64ArrayAttr (blockSize));
124121 auto castOp = rewriter.create <UnrealizedConversionCastOp>(
125- loc, destTy, srcs,
126- llvm::ArrayRef<NamedAttribute>({attr, innerBlkAttr}));
122+ loc, destTy, srcs, llvm::ArrayRef<NamedAttribute>({attr, blkAttr}));
127123 return castOp.getResult (0 );
128124 }
129125
@@ -150,15 +146,15 @@ struct UnrollPattern : public OpRewritePattern<SourceOp> {
150146 }
151147 }
152148 return results;
153- } else if (isa<xegpu::TensorDescType>(src.getType ())) {
149+ }
150+
151+ if (isa<xegpu::TensorDescType>(src.getType ())) {
154152 auto attr = NamedAttribute (rewriter.getStringAttr (packAttrName),
155153 rewriter.getUnitAttr ());
156- auto innerBlkAttr =
157- NamedAttribute (rewriter.getStringAttr (blockAttrName),
158- rewriter.getDenseI64ArrayAttr (blockSize));
154+ auto blkAttr = NamedAttribute (rewriter.getStringAttr (blockAttrName),
155+ rewriter.getDenseI64ArrayAttr (blockSize));
159156 auto castOp = rewriter.create <UnrealizedConversionCastOp>(
160- loc, destTypes, src,
161- llvm::ArrayRef<NamedAttribute>({attr, innerBlkAttr}));
157+ loc, destTypes, src, llvm::ArrayRef<NamedAttribute>({attr, blkAttr}));
162158 return castOp.getResults ();
163159 }
164160
@@ -242,11 +238,70 @@ struct UnrollCreateNdOp : public UnrollPattern<xegpu::CreateNdDescOp> {
242238 }
243239};
244240
241+ struct UnrollUpdateNdOffsetOp : public UnrollPattern <xegpu::UpdateNdOffsetOp> {
242+ using UnrollPattern<xegpu::UpdateNdOffsetOp>::UnrollPattern;
243+ LogicalResult matchAndRewrite (xegpu::UpdateNdOffsetOp op,
244+ PatternRewriter &rewriter) const override {
245+ auto loc = op.getLoc ();
246+ auto tdesc = op.getTensorDesc ();
247+ auto tdescTy = tdesc.getType ();
248+ auto shape = tdescTy.getShape ();
249+
250+ auto maybeTargetShape = getTargetShape (op);
251+ if (!maybeTargetShape)
252+ return failure ();
253+ auto targetShape = *maybeTargetShape;
254+
255+ auto maybeGrids = computeGrids (shape, targetShape);
256+ if (!maybeGrids)
257+ return failure ();
258+ auto grids = *maybeGrids;
259+
260+ auto convertedTdescTypes = convertType (tdescTy, targetShape);
261+ auto convertedTdesc =
262+ pack (tdesc, *convertedTdescTypes, targetShape, loc, rewriter);
263+
264+ llvm::SmallVector<Value> newOps;
265+ for (auto t : convertedTdesc) {
266+ auto newOp = rewriter.create <xegpu::UpdateNdOffsetOp>(
267+ loc, t.getType (), t, op.getOffsets (), op.getConstOffsets ());
268+ newOps.push_back (newOp);
269+ }
270+ auto castOp = unpack (newOps, op.getType (), targetShape, loc, rewriter);
271+ rewriter.replaceOp (op, castOp);
272+ return success ();
273+ }
274+ };
275+
245276struct UnrollPrefetchNdOp : public UnrollPattern <xegpu::PrefetchNdOp> {
246277 using UnrollPattern<xegpu::PrefetchNdOp>::UnrollPattern;
247278 LogicalResult matchAndRewrite (xegpu::PrefetchNdOp op,
248279 PatternRewriter &rewriter) const override {
249- return failure ();
280+ auto loc = op.getLoc ();
281+ auto tdesc = op.getTensorDesc ();
282+ auto tdescTy = tdesc.getType ();
283+ auto shape = tdescTy.getShape ();
284+
285+ auto maybeTargetShape = getTargetShape (op);
286+ if (!maybeTargetShape)
287+ return failure ();
288+ auto targetShape = *maybeTargetShape;
289+
290+ auto maybeGrids = computeGrids (shape, targetShape);
291+ if (!maybeGrids)
292+ return failure ();
293+ auto grids = *maybeGrids;
294+
295+ auto convertedTdescTypes = convertType (tdescTy, targetShape);
296+ auto convertedTdesc =
297+ pack (tdesc, *convertedTdescTypes, targetShape, loc, rewriter);
298+
299+ for (auto t : convertedTdesc) {
300+ rewriter.create <xegpu::PrefetchNdOp>(loc, TypeRange (), t, op->getAttrs ());
301+ }
302+
303+ rewriter.eraseOp (op);
304+ return success ();
250305 }
251306};
252307
@@ -333,54 +388,6 @@ struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
333388 }
334389};
335390
336- struct UnrollUpdateNdOffsetOp : public UnrollPattern <xegpu::UpdateNdOffsetOp> {
337- using UnrollPattern<xegpu::UpdateNdOffsetOp>::UnrollPattern;
338- LogicalResult matchAndRewrite (xegpu::UpdateNdOffsetOp op,
339- PatternRewriter &rewriter) const override {
340- return failure ();
341- }
342- };
343-
344- struct UnrollCreateDescOp : public UnrollPattern <xegpu::CreateDescOp> {
345- using UnrollPattern<xegpu::CreateDescOp>::UnrollPattern;
346- LogicalResult matchAndRewrite (xegpu::CreateDescOp op,
347- PatternRewriter &rewriter) const override {
348- return failure ();
349- }
350- };
351-
352- struct UnrollPrefetchOp : public UnrollPattern <xegpu::PrefetchOp> {
353- using UnrollPattern<xegpu::PrefetchOp>::UnrollPattern;
354- LogicalResult matchAndRewrite (xegpu::PrefetchOp op,
355- PatternRewriter &rewriter) const override {
356- return failure ();
357- }
358- };
359-
360- struct UnrollLoadOp : public UnrollPattern <xegpu::LoadGatherOp> {
361- using UnrollPattern<xegpu::LoadGatherOp>::UnrollPattern;
362- LogicalResult matchAndRewrite (xegpu::LoadGatherOp op,
363- PatternRewriter &rewriter) const override {
364- return failure ();
365- }
366- };
367-
368- struct UnrollStoreOp : public UnrollPattern <xegpu::StoreScatterOp> {
369- using UnrollPattern<xegpu::StoreScatterOp>::UnrollPattern;
370- LogicalResult matchAndRewrite (xegpu::StoreScatterOp op,
371- PatternRewriter &rewriter) const override {
372- return failure ();
373- }
374- };
375-
376- struct UnrollUpdateOffsetOp : public UnrollPattern <xegpu::UpdateOffsetOp> {
377- using UnrollPattern<xegpu::UpdateOffsetOp>::UnrollPattern;
378- LogicalResult matchAndRewrite (xegpu::UpdateOffsetOp op,
379- PatternRewriter &rewriter) const override {
380- return failure ();
381- }
382- };
383-
384391struct UnrollDpasOp : public UnrollPattern <xegpu::DpasOp> {
385392 using UnrollPattern<xegpu::DpasOp>::UnrollPattern;
386393 LogicalResult matchAndRewrite (xegpu::DpasOp op,
@@ -468,18 +475,12 @@ struct UnrollDpasOp : public UnrollPattern<xegpu::DpasOp> {
468475 }
469476};
470477
471- struct UnrollAtomicRMWOp : public UnrollPattern <xegpu::AtomicRMWOp> {
472- using UnrollPattern<xegpu::AtomicRMWOp>::UnrollPattern;
473- LogicalResult matchAndRewrite (xegpu::AtomicRMWOp op,
474- PatternRewriter &rewriter) const override {
475- return failure ();
476- }
477- };
478478} // namespace
479479
480480void mlir::xegpu::populateXeGPUUnrollPatterns (
481481 RewritePatternSet &patterns,
482482 const mlir::vector::UnrollVectorOptions &options) {
483- patterns.add <UnrollCreateNdOp, UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp>(
483+ patterns.add <UnrollCreateNdOp, UnrollUpdateNdOffsetOp, UnrollPrefetchNdOp,
484+ UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp>(
484485 patterns.getContext (), options);
485486}
0 commit comments