@@ -151,6 +151,21 @@ translateStoreXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
151151 }
152152}
153153
154+ // Compute the product of sizes in the range [lo, hi) from the sizes array.
155+ static Value getProductOfSizes (ConversionPatternRewriter &rewriter,
156+ Location loc, ArrayRef<OpFoldResult> sizes,
157+ size_t lo, size_t hi) {
158+ Type indexTy = rewriter.getIndexType ();
159+ Value product = arith::ConstantIndexOp::create (rewriter, loc, 1 );
160+ for (size_t idx = lo; idx < hi; idx++) {
161+ OpFoldResult ofr = sizes[idx];
162+ Value sizeVal = getValueOrCreateConstantIntOp (rewriter, loc, ofr);
163+ sizeVal = getValueOrCreateCastToIndexLike (rewriter, loc, indexTy, sizeVal);
164+ product = rewriter.createOrFold <arith::MulIOp>(loc, product, sizeVal);
165+ }
166+ return product;
167+ }
168+
154169class CreateNdDescToXeVMPattern
155170 : public OpConversionPattern<xegpu::CreateNdDescOp> {
156171 using OpConversionPattern::OpConversionPattern;
@@ -184,10 +199,9 @@ class CreateNdDescToXeVMPattern
184199
185200 // Source can be a memref or a pointer (ui64, ui32, i64 or i32).
186201 SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes ();
187- // Descriptor shape is expected to be 2D.
188- int64_t rank = mixedSizes.size ();
189- if (rank != 2 )
190- return rewriter.notifyMatchFailure (op, " Expected 2D shape." );
202+ auto srcRank = mixedSizes.size ();
203+ if (srcRank < 2 )
204+ return rewriter.notifyMatchFailure (op, " Expected at least 2D source." );
191205
192206 auto sourceTy = source.getType ();
193207 auto sourceMemrefTy = dyn_cast<MemRefType>(sourceTy);
@@ -203,18 +217,23 @@ class CreateNdDescToXeVMPattern
203217 baseAddr = adaptor.getSource ();
204218 }
205219 // Utility for creating offset values from op fold result.
206- auto createOffset = [&](SmallVector<OpFoldResult> &ofrVec,
207- unsigned idx) -> Value {
208- Value val = getValueOrCreateConstantIntOp (rewriter, loc, ofrVec[idx]);
220+ auto createOffset = [&](OpFoldResult ofr) -> Value {
221+ Value val = getValueOrCreateConstantIntOp (rewriter, loc, ofr);
209222 val = getValueOrCreateCastToIndexLike (rewriter, loc, payloadElemTy, val);
210223 return val;
211224 };
212225 // Offsets are not supported (0 is used).
213226 offsetW = arith::ConstantIntOp::create (rewriter, loc, payloadElemTy, 0 );
214227 offsetH = arith::ConstantIntOp::create (rewriter, loc, payloadElemTy, 0 );
215228 // Get shape values from op fold results.
216- baseShapeW = createOffset (mixedSizes, 1 );
217- baseShapeH = createOffset (mixedSizes, 0 );
229+ baseShapeW = createOffset (mixedSizes[srcRank - 1 ]);
230+ if (srcRank == 2 ) {
231+ baseShapeH = createOffset (mixedSizes[0 ]);
232+ } else {
233+ // Generate compute chain for height (product of sizes of all but the last
234+ // dimension).
235+ baseShapeH = getProductOfSizes (rewriter, loc, mixedSizes, 0 , srcRank - 1 );
236+ }
218237 if (sourceMemrefTy) {
219238 // Cast index to i64.
220239 baseAddr = arith::IndexCastUIOp::create (rewriter, loc, i64Ty, baseAddr);
@@ -255,10 +274,18 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
255274 LogicalResult
256275 matchAndRewrite (OpType op, typename OpType::Adaptor adaptor,
257276 ConversionPatternRewriter &rewriter) const override {
277+ auto tdVal = op.getTensorDesc ();
278+ xegpu::CreateNdDescOp descOp =
279+ tdVal.template getDefiningOp <xegpu::CreateNdDescOp>();
280+ auto mixedStrides = descOp.getMixedStrides ();
258281 auto mixedOffsets = op.getMixedOffsets ();
259- int64_t opOffsetsSize = mixedOffsets.size ();
260- if (opOffsetsSize != 2 )
261- return rewriter.notifyMatchFailure (op, " Expected 2D offsets." );
282+ auto mixedSizes = descOp.getMixedSizes ();
283+ size_t opOffsetsSize = mixedOffsets.size ();
284+ if (opOffsetsSize != mixedStrides.size ())
285+ return rewriter.notifyMatchFailure (
286+ op, " Offsets size should match base memory rank." );
287+ if (opOffsetsSize < 2 )
288+ return rewriter.notifyMatchFailure (op, " Expected at least 2D offset." );
262289 auto loc = op.getLoc ();
263290 auto ctxt = rewriter.getContext ();
264291
@@ -283,12 +310,35 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
283310 rewriter, loc, tdesc, static_cast <int >(NdTdescOffset::BaseShapeH));
284311 // Offsets are provided by the op.
285312 // convert them to i32.
286- Value offsetW =
287- getValueOrCreateConstantIntOp (rewriter, loc, mixedOffsets[1 ]);
313+ // Offset computation assumes base memory layout is row major.
314+ Value offsetW = getValueOrCreateConstantIntOp (
315+ rewriter, loc, mixedOffsets[opOffsetsSize - 1 ]);
288316 offsetW = getValueOrCreateCastToIndexLike (rewriter, loc,
289317 rewriter.getI32Type (), offsetW);
290- Value offsetH =
291- getValueOrCreateConstantIntOp (rewriter, loc, mixedOffsets[0 ]);
318+ Value offsetH;
319+ if (opOffsetsSize == 2 )
320+ offsetH = getValueOrCreateConstantIntOp (rewriter, loc, mixedOffsets[0 ]);
321+ else {
322+ offsetH = arith::ConstantIndexOp::create (rewriter, loc, 0 );
323+ Value tmpStride = arith::ConstantIndexOp::create (rewriter, loc, 1 );
324+ // offsetH requires computing the linear offset using the strides.
325+ for (size_t idx = 0 ; idx < opOffsetsSize - 1 ; idx++) {
326+ size_t revIdx = opOffsetsSize - 2 - idx;
327+ Value offsetVal =
328+ getValueOrCreateConstantIntOp (rewriter, loc, mixedOffsets[revIdx]);
329+ offsetVal = getValueOrCreateCastToIndexLike (
330+ rewriter, loc, rewriter.getIndexType (), offsetVal);
331+ Value mul =
332+ rewriter.createOrFold <arith::MulIOp>(loc, tmpStride, offsetVal);
333+ Value dimSize =
334+ getValueOrCreateConstantIntOp (rewriter, loc, mixedSizes[revIdx]);
335+ dimSize = getValueOrCreateCastToIndexLike (
336+ rewriter, loc, rewriter.getIndexType (), dimSize);
337+ tmpStride =
338+ rewriter.createOrFold <arith::MulIOp>(loc, tmpStride, dimSize);
339+ offsetH = rewriter.createOrFold <arith::AddIOp>(loc, offsetH, mul);
340+ }
341+ }
292342 offsetH = getValueOrCreateCastToIndexLike (rewriter, loc,
293343 rewriter.getI32Type (), offsetH);
294344 // Get address space from tensor descriptor memory space.
0 commit comments