Skip to content

Commit 4a92953

Browse files
committed
[MLIR][Conversion] XeGPU to XeVM: Create nd tensor descriptor payload for base memory rank > 2
1 parent 925b106 commit 4a92953

File tree

1 file changed

+66
-16
lines changed

1 file changed

+66
-16
lines changed

mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp

Lines changed: 66 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,21 @@ translateStoreXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
151151
}
152152
}
153153

154+
// Compute the product of the sizes in the half-open range [lo, hi) of the
// `sizes` array and return it as an index-typed Value.
//
// Each selected entry is turned into a Value, cast to the index type, and
// multiplied into a running product that starts at the index constant 1.
// Consequently an empty range (lo >= hi) returns that constant 1.
//
// No bounds check is performed here: callers must pass hi <= sizes.size().
155+
static Value getProductOfSizes(ConversionPatternRewriter &rewriter,
156+
Location loc, ArrayRef<OpFoldResult> sizes,
157+
size_t lo, size_t hi) {
158+
Type indexTy = rewriter.getIndexType();
159+
Value product = arith::ConstantIndexOp::create(rewriter, loc, 1);
160+
for (size_t idx = lo; idx < hi; idx++) {
161+
OpFoldResult ofr = sizes[idx];
162+
Value sizeVal = getValueOrCreateConstantIntOp(rewriter, loc, ofr);
163+
sizeVal = getValueOrCreateCastToIndexLike(rewriter, loc, indexTy, sizeVal);
164+
// NOTE(review): createOrFold presumably lets multiplications of constant
// operands fold instead of emitting an op -- confirm against OpBuilder docs.
product = rewriter.createOrFold<arith::MulIOp>(loc, product, sizeVal);
165+
}
166+
return product;
167+
}
168+
154169
class CreateNdDescToXeVMPattern
155170
: public OpConversionPattern<xegpu::CreateNdDescOp> {
156171
using OpConversionPattern::OpConversionPattern;
@@ -184,10 +199,9 @@ class CreateNdDescToXeVMPattern
184199

185200
// Source can be a memref or a pointer (ui64, ui32, i64 or i32).
186201
SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
187-
// Descriptor shape is expected to be 2D.
188-
int64_t rank = mixedSizes.size();
189-
if (rank != 2)
190-
return rewriter.notifyMatchFailure(op, "Expected 2D shape.");
202+
auto srcRank = mixedSizes.size();
203+
if (srcRank < 2)
204+
return rewriter.notifyMatchFailure(op, "Expected at least 2D source.");
191205

192206
auto sourceTy = source.getType();
193207
auto sourceMemrefTy = dyn_cast<MemRefType>(sourceTy);
@@ -203,18 +217,23 @@ class CreateNdDescToXeVMPattern
203217
baseAddr = adaptor.getSource();
204218
}
205219
// Utility for creating offset values from op fold result.
206-
auto createOffset = [&](SmallVector<OpFoldResult> &ofrVec,
207-
unsigned idx) -> Value {
208-
Value val = getValueOrCreateConstantIntOp(rewriter, loc, ofrVec[idx]);
220+
auto createOffset = [&](OpFoldResult ofr) -> Value {
221+
Value val = getValueOrCreateConstantIntOp(rewriter, loc, ofr);
209222
val = getValueOrCreateCastToIndexLike(rewriter, loc, payloadElemTy, val);
210223
return val;
211224
};
212225
// Offsets are not supported (0 is used).
213226
offsetW = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
214227
offsetH = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
215228
// Get shape values from op fold results.
216-
baseShapeW = createOffset(mixedSizes, 1);
217-
baseShapeH = createOffset(mixedSizes, 0);
229+
baseShapeW = createOffset(mixedSizes[srcRank - 1]);
230+
if (srcRank == 2) {
231+
baseShapeH = createOffset(mixedSizes[0]);
232+
} else {
233+
// Generate compute chain for height (product of sizes of all but the last
234+
// dimension).
235+
baseShapeH = getProductOfSizes(rewriter, loc, mixedSizes, 0, srcRank - 1);
236+
}
218237
if (sourceMemrefTy) {
219238
// Cast index to i64.
220239
baseAddr = arith::IndexCastUIOp::create(rewriter, loc, i64Ty, baseAddr);
@@ -255,10 +274,18 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
255274
LogicalResult
256275
matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
257276
ConversionPatternRewriter &rewriter) const override {
277+
auto tdVal = op.getTensorDesc();
278+
xegpu::CreateNdDescOp descOp =
279+
tdVal.template getDefiningOp<xegpu::CreateNdDescOp>();
280+
auto mixedStrides = descOp.getMixedStrides();
258281
auto mixedOffsets = op.getMixedOffsets();
259-
int64_t opOffsetsSize = mixedOffsets.size();
260-
if (opOffsetsSize != 2)
261-
return rewriter.notifyMatchFailure(op, "Expected 2D offsets.");
282+
auto mixedSizes = descOp.getMixedSizes();
283+
size_t opOffsetsSize = mixedOffsets.size();
284+
if (opOffsetsSize != mixedStrides.size())
285+
return rewriter.notifyMatchFailure(
286+
op, "Offsets size should match base memory rank.");
287+
if (opOffsetsSize < 2)
288+
return rewriter.notifyMatchFailure(op, "Expected at least 2D offset.");
262289
auto loc = op.getLoc();
263290
auto ctxt = rewriter.getContext();
264291

@@ -283,12 +310,35 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
283310
rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH));
284311
// Offsets are provided by the op.
285312
// convert them to i32.
286-
Value offsetW =
287-
getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]);
313+
// Offset computation assumes base memory layout is row major.
314+
Value offsetW = getValueOrCreateConstantIntOp(
315+
rewriter, loc, mixedOffsets[opOffsetsSize - 1]);
288316
offsetW = getValueOrCreateCastToIndexLike(rewriter, loc,
289317
rewriter.getI32Type(), offsetW);
290-
Value offsetH =
291-
getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
318+
Value offsetH;
319+
if (opOffsetsSize == 2)
320+
offsetH = getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
321+
else {
322+
offsetH = arith::ConstantIndexOp::create(rewriter, loc, 0);
323+
Value tmpStride = arith::ConstantIndexOp::create(rewriter, loc, 1);
324+
// offsetH requires computing the linear offset using the strides.
325+
for (size_t idx = 0; idx < opOffsetsSize - 1; idx++) {
326+
size_t revIdx = opOffsetsSize - 2 - idx;
327+
Value offsetVal =
328+
getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[revIdx]);
329+
offsetVal = getValueOrCreateCastToIndexLike(
330+
rewriter, loc, rewriter.getIndexType(), offsetVal);
331+
Value mul =
332+
rewriter.createOrFold<arith::MulIOp>(loc, tmpStride, offsetVal);
333+
Value dimSize =
334+
getValueOrCreateConstantIntOp(rewriter, loc, mixedSizes[revIdx]);
335+
dimSize = getValueOrCreateCastToIndexLike(
336+
rewriter, loc, rewriter.getIndexType(), dimSize);
337+
tmpStride =
338+
rewriter.createOrFold<arith::MulIOp>(loc, tmpStride, dimSize);
339+
offsetH = rewriter.createOrFold<arith::AddIOp>(loc, offsetH, mul);
340+
}
341+
}
292342
offsetH = getValueOrCreateCastToIndexLike(rewriter, loc,
293343
rewriter.getI32Type(), offsetH);
294344
// Get address space from tensor descriptor memory space.

0 commit comments

Comments
 (0)