@@ -56,8 +56,9 @@ struct UnrollPattern : public OpRewritePattern<SourceOp> {
5656 }
5757
5858 SmallVector<Type> getUnrolledTypes (ShapedType type,
59- ArrayRef<int64_t > tileShape) const {
60- return options.getUnrolledTypes (type, tileShape);
59+ ArrayRef<int64_t > tileShape,
60+ bool returnSingleType = false ) const {
61+ return options.getUnrolledTypes (type, tileShape, returnSingleType);
6162 }
6263
6364 // / Emulate the the unpack behavior using insert_strided_slice for VectorType
@@ -121,53 +122,79 @@ struct UnrollPattern : public OpRewritePattern<SourceOp> {
121122 xegpu::UnrollOptions options;
122123};
123124
125+ // Generic helper function for unrolling operations with offsets.
126+ //
127+ // Iterates over tile offsets within the tensor descriptor shape and calls
128+ // the provided createOp function for each computed offset. This is used by
129+ // operations like LoadNd, StoreNd, CreateNdDesc, and PrefetchNd when they
130+ // have explicit offsets that need to be adjusted for each unrolled tile.
131+ SmallVector<Value> computeUnrolledOffsets (
132+ SmallVector<OpFoldResult> mixedOffsets, xegpu::TensorDescType tdescTy,
133+ ArrayRef<int64_t > targetShape,
134+ const std::function<Value(SmallVector<OpFoldResult>)> &createOp,
135+ Location loc, PatternRewriter &rewriter) {
136+ int64_t rank = tdescTy.getRank ();
137+ ArrayRef<int64_t > shape = tdescTy.getShape ();
138+
139+ auto addi = [&](OpFoldResult a, int64_t b) -> Value {
140+ std::optional<int64_t > maybeInt = getConstantIntValue (a);
141+ if (maybeInt) {
142+ return arith::ConstantIndexOp::create (rewriter, loc, *maybeInt + b);
143+ } else {
144+ auto aV = llvm::cast<Value>(a);
145+ auto bV = arith::ConstantIndexOp::create (rewriter, loc, b);
146+ return rewriter.createOrFold <arith::AddIOp>(loc, aV, bV);
147+ }
148+ };
149+
150+ SmallVector<OpFoldResult> oldOffsets = llvm::to_vector (
151+ llvm::drop_begin (mixedOffsets, mixedOffsets.size () - rank));
152+ auto validIdxes =
153+ llvm::seq<int64_t >(mixedOffsets.size () - rank, mixedOffsets.size ());
154+
155+ SmallVector<Value> newOps;
156+ for (SmallVector<int64_t > offsets :
157+ StaticTileOffsetRange (shape, targetShape)) {
158+
159+ for (auto [idx, oldOff, offset] :
160+ llvm::zip (validIdxes, oldOffsets, offsets))
161+ mixedOffsets[idx] = addi (oldOff, offset);
162+
163+ auto newOp = createOp (mixedOffsets);
164+ newOps.push_back (newOp);
165+ }
166+ return newOps;
167+ }
168+
124169struct UnrollCreateNdOp : public UnrollPattern <xegpu::CreateNdDescOp> {
125170 using UnrollPattern<xegpu::CreateNdDescOp>::UnrollPattern;
126171 LogicalResult matchAndRewrite (xegpu::CreateNdDescOp op,
127172 PatternRewriter &rewriter) const override {
128173 Location loc = op.getLoc ();
129174 xegpu::TensorDescType tdescTy = op.getType ();
130- int64_t rank = tdescTy.getRank ();
131- ArrayRef<int64_t > shape = tdescTy.getShape ();
132175
133176 std::optional<SmallVector<int64_t >> targetShape = getTargetShape (op);
134177 if (!targetShape)
135178 return failure ();
136179
137- auto newTdescTy = getUnrolledTypes (tdescTy, *targetShape)[0 ];
138-
139- auto addi = [&](OpFoldResult a, int64_t b) -> Value {
140- std::optional<int64_t > maybeInt = getConstantIntValue (a);
141- if (maybeInt) {
142- return arith::ConstantIndexOp::create (rewriter, loc, *maybeInt + b);
143- } else {
144- auto aV = llvm::cast<Value>(a);
145- auto bV = arith::ConstantIndexOp::create (rewriter, loc, b);
146- return rewriter.createOrFold <arith::AddIOp>(loc, aV, bV);
147- }
148- };
149-
150- SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets ();
151-
152- // For n-D memrefs where n > rank, we need to handle the last `rank`
153- // dimensions only, and keep the first `n-rank` dimensions as is.
154- SmallVector<OpFoldResult> oldOffsets = llvm::to_vector (
155- llvm::drop_begin (mixedOffsets, mixedOffsets.size () - rank));
156- auto validIdxes =
157- llvm::seq<int64_t >(mixedOffsets.size () - rank, mixedOffsets.size ());
158-
159180 SmallVector<Value> newOps;
160- for (SmallVector<int64_t > offsets :
161- StaticTileOffsetRange (shape, *targetShape)) {
162-
163- for (auto [idx, oldOff, offset] :
164- llvm::zip (validIdxes, oldOffsets, offsets))
165- mixedOffsets[idx] = addi (oldOff, offset);
166181
182+ auto newTdescTy = getUnrolledTypes (tdescTy, *targetShape)[0 ];
183+ bool hasOffsets = op.getMixedOffsets ().size () != 0 ;
184+ if (!hasOffsets) {
167185 auto newOp = xegpu::CreateNdDescOp::create (
168- rewriter, loc, newTdescTy, op.getSource (), mixedOffsets ,
169- op.getMixedSizes (), op. getMixedStrides ());
186+ rewriter, loc, newTdescTy, op.getSource (), op. getMixedSizes () ,
187+ op.getMixedStrides ());
170188 newOps.push_back (newOp);
189+ } else {
190+ auto createOp = [&](SmallVector<OpFoldResult> offsets) -> Value {
191+ return xegpu::CreateNdDescOp::create (
192+ rewriter, loc, newTdescTy, op.getSource (), offsets,
193+ op.getMixedSizes (), op.getMixedStrides ());
194+ };
195+
196+ newOps = computeUnrolledOffsets (op.getMixedOffsets (), tdescTy,
197+ *targetShape, createOp, loc, rewriter);
171198 }
172199 Value castOp = unpack (newOps, tdescTy, *targetShape, loc, rewriter);
173200 rewriter.replaceOp (op, castOp);
@@ -216,17 +243,30 @@ struct UnrollPrefetchNdOp : public UnrollPattern<xegpu::PrefetchNdOp> {
216243 return failure ();
217244
218245 int64_t offsetSize = static_cast <int64_t >(op.getOffsets ().size ());
219- if ((offsetSize != 0 ) || op.getConstOffsetsAttr ())
220- return failure ();
246+ bool hasOffsets = (offsetSize != 0 ) || op.getConstOffsetsAttr ();
247+
248+ SmallVector<Type> convertedTdescTypes = getUnrolledTypes (
249+ tdescTy, *targetShape, /* returnSingleType*/ hasOffsets);
221250
222- SmallVector<Type> convertedTdescTypes =
223- getUnrolledTypes (tdescTy, *targetShape);
224251 SmallVector<Value> convertedTdesc = pack (
225252 op.getTensorDesc (), convertedTdescTypes, *targetShape, loc, rewriter);
226253
227- for (auto t : convertedTdesc)
228- xegpu::PrefetchNdOp::create (rewriter, loc, TypeRange (), t,
229- op->getAttrs ());
254+ if (!hasOffsets) {
255+ for (auto t : convertedTdesc)
256+ xegpu::PrefetchNdOp::create (rewriter, loc, TypeRange (), t,
257+ op->getAttrs ());
258+ } else {
259+ auto createPrefetch = [&](SmallVector<OpFoldResult> offsets) -> Value {
260+ xegpu::PrefetchNdOp::create (rewriter, loc, convertedTdesc[0 ], offsets,
261+ op.getL1HintAttr (), op.getL2HintAttr (),
262+ op.getL3HintAttr ());
263+ // return dummy Value to satisfy function's signature
264+ return nullptr ;
265+ };
266+
267+ computeUnrolledOffsets (op.getMixedOffsets (), tdescTy, *targetShape,
268+ createPrefetch, loc, rewriter);
269+ }
230270
231271 rewriter.eraseOp (op);
232272 return success ();
@@ -247,22 +287,33 @@ struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
247287 return failure ();
248288
249289 int64_t offsetSize = static_cast <int64_t >(op.getOffsets ().size ());
250- if ((offsetSize != 0 ) || op.getConstOffsetsAttr ())
251- return failure ();
290+ bool hasOffsets = (offsetSize != 0 ) || op.getConstOffsetsAttr ();
252291
253292 Type elemTy = tdescTy.getElementType ();
254293 VectorType newValueTy = valueTy.cloneWith (*targetShape, elemTy);
255294
256- SmallVector<Type> convertedTdescTypes =
257- getUnrolledTypes (tdescTy, *targetShape);
295+ SmallVector<Type> convertedTdescTypes = getUnrolledTypes (
296+ tdescTy, *targetShape, /* returnSingleType*/ hasOffsets);
297+
258298 SmallVector<Value> convertedTdescs = pack (
259299 op.getTensorDesc (), convertedTdescTypes, *targetShape, loc, rewriter);
260-
261300 SmallVector<Value> newOps;
262- for (auto t : convertedTdescs) {
263- auto newOp =
264- xegpu::LoadNdOp::create (rewriter, loc, newValueTy, t, op->getAttrs ());
265- newOps.push_back (newOp);
301+
302+ if (!hasOffsets) {
303+ for (auto t : convertedTdescs) {
304+ auto newOp = xegpu::LoadNdOp::create (rewriter, loc, newValueTy, t,
305+ op->getAttrs ());
306+ newOps.push_back (newOp);
307+ }
308+ } else {
309+ auto createLoad = [&](SmallVector<OpFoldResult> offsets) {
310+ return xegpu::LoadNdOp::create (
311+ rewriter, loc, newValueTy, convertedTdescs[0 ], offsets,
312+ op.getPackedAttr (), op.getTransposeAttr (), op.getL1HintAttr (),
313+ op.getL2HintAttr (), op.getL3HintAttr ());
314+ };
315+ newOps = computeUnrolledOffsets (op.getMixedOffsets (), tdescTy,
316+ *targetShape, createLoad, loc, rewriter);
266317 }
267318
268319 Value castOp = unpack (newOps, op.getType (), *targetShape, loc, rewriter);
@@ -285,22 +336,36 @@ struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
285336 return failure ();
286337
287338 int64_t offsetSize = static_cast <int64_t >(op.getOffsets ().size ());
288- if ((offsetSize != 0 ) || op.getConstOffsetsAttr ())
289- return failure ();
339+ bool hasOffsets = (offsetSize != 0 ) || op.getConstOffsetsAttr ();
290340
291341 SmallVector<Type> convertedValTypes =
292342 getUnrolledTypes (valueTy, *targetShape);
293- SmallVector<Type> convertedTdescTypes =
294- getUnrolledTypes ( tdescTy, *targetShape);
343+ SmallVector<Type> convertedTdescTypes = getUnrolledTypes (
344+ tdescTy, *targetShape, /* returnSingleType */ hasOffsets );
295345
296- SmallVector<Value> convertedValues =
297- pack (op.getValue (), convertedValTypes, *targetShape, loc, rewriter);
298346 SmallVector<Value> convertedTdescs = pack (
299347 op.getTensorDesc (), convertedTdescTypes, *targetShape, loc, rewriter);
300348
301- for (auto [v, t] : llvm::zip (convertedValues, convertedTdescs))
302- xegpu::StoreNdOp::create (rewriter, loc, v, t, op.getL1HintAttr (),
303- op.getL2HintAttr (), op.getL3HintAttr ());
349+ SmallVector<Value> convertedValues =
350+ pack (op.getValue (), convertedValTypes, *targetShape, loc, rewriter);
351+ if (!hasOffsets) {
352+ for (auto [v, t] : llvm::zip (convertedValues, convertedTdescs))
353+ xegpu::StoreNdOp::create (rewriter, loc, v, t, op.getL1HintAttr (),
354+ op.getL2HintAttr (), op.getL3HintAttr ());
355+ } else {
356+ size_t valueIndex = 0 ;
357+ auto createStore = [&](SmallVector<OpFoldResult> offsets) {
358+ xegpu::StoreNdOp::create (rewriter, loc, convertedValues[valueIndex++],
359+ convertedTdescs[0 ], offsets,
360+ op.getL1HintAttr (), op.getL2HintAttr (),
361+ op.getL3HintAttr ());
362+ // return dummy Value to satisfy function's signature
363+ return nullptr ;
364+ };
365+
366+ computeUnrolledOffsets (op.getMixedOffsets (), tdescTy, *targetShape,
367+ createStore, loc, rewriter);
368+ }
304369
305370 rewriter.eraseOp (op);
306371 return success ();
0 commit comments