
Commit 1b9b790

[MLIR][Conversion] Convert XeGPU to XeVM pass: Remove lowering support for tensor descriptor with offsets. (#157550)
Also update the load/store/prefetch test cases to use direct offsets, since tensor descriptors with offsets are being deprecated.
1 parent f2d827c
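For context, a minimal sketch of the two XeGPU forms involved (the offset values, SSA names, and the exact xegpu.load_nd offset syntax below are illustrative assumptions, not taken verbatim from this commit). The pass now rejects descriptors that carry offsets and expects offsets on the access op itself:

// Deprecated: offsets baked into the tensor descriptor (no longer lowered by this pass).
%tdesc_old = xegpu.create_nd_tdesc %src[%off_h, %off_w] : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32>

// Supported: offset-free descriptor; direct offsets are passed to the load/store/prefetch op.
%tdesc = xegpu.create_nd_tdesc %src : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
%data = xegpu.load_nd %tdesc[%off_h, %off_w] : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>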

5 files changed: +85 -424 lines changed
mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp

Lines changed: 34 additions & 136 deletions
@@ -154,6 +154,9 @@ class CreateNdDescToXeVMPattern
   matchAndRewrite(xegpu::CreateNdDescOp op,
                   xegpu::CreateNdDescOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
+    if (mixedOffsets.size() != 0)
+      return rewriter.notifyMatchFailure(op, "Offsets not supported.");
     auto loc = op.getLoc();
     auto source = op.getSource();
     // Op is lowered to a code sequence that populates payload.
@@ -177,7 +180,6 @@ class CreateNdDescToXeVMPattern
 
     // Source can be a memref or a pointer (ui64, ui32, i64 or i32).
     SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
-    SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
     // Descriptor shape is expected to be 2D.
     int64_t rank = mixedSizes.size();
     if (rank != 2)
@@ -202,17 +204,9 @@ class CreateNdDescToXeVMPattern
       val = getValueOrCreateCastToIndexLike(rewriter, loc, payloadElemTy, val);
       return val;
     };
-    // Offsets can be either 2D or not provided (0 is used).
-    if (mixedOffsets.size() == 2) {
-      offsetW = createOffset(mixedOffsets, 1);
-      offsetH = createOffset(mixedOffsets, 0);
-    } else if (mixedOffsets.size() == 0) {
-      offsetW = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
-      offsetH = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
-    } else {
-      return rewriter.notifyMatchFailure(op,
-                                         "Expected 2D offsets or no offsets.");
-    }
+    // Offsets are not supported (0 is used).
+    offsetW = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
+    offsetH = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
     // Get shape values from op fold results.
     baseShapeW = createOffset(mixedSizes, 1);
     baseShapeH = createOffset(mixedSizes, 0);
@@ -247,39 +241,6 @@ class CreateNdDescToXeVMPattern
   }
 };
 
-class UpdateNdOffsetToXeVMPattern
-    : public OpConversionPattern<xegpu::UpdateNdOffsetOp> {
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(xegpu::UpdateNdOffsetOp op,
-                  xegpu::UpdateNdOffsetOp::Adaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto loc = op.getLoc();
-    auto mixedOffsets = op.getMixedOffsets();
-    // Only 2D offsets are supported for now.
-    if (mixedOffsets.size() != 2)
-      return rewriter.notifyMatchFailure(op, "Expected 2D offsets.");
-    auto payload = adaptor.getTensorDesc();
-    // Utility for updating payload offset values from op fold result.
-    auto updateOffset = [&](unsigned idx, int payloadPos) -> Value {
-      Value offset =
-          getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[idx]);
-      offset = getValueOrCreateCastToIndexLike(rewriter, loc,
-                                               rewriter.getI32Type(), offset);
-      Value oldOffset =
-          vector::ExtractOp::create(rewriter, loc, payload, payloadPos);
-      Value newOffset = arith::AddIOp::create(rewriter, loc, oldOffset, offset);
-      return vector::InsertOp::create(rewriter, loc, newOffset, payload,
-                                      payloadPos);
-    };
-    // Update offsets in the payload.
-    payload = updateOffset(0, static_cast<int>(NdTdescOffset::TensorOffsetH));
-    payload = updateOffset(1, static_cast<int>(NdTdescOffset::TensorOffsetW));
-    rewriter.replaceOp(op, payload);
-    return success();
-  }
-};
-
 template <
     typename OpType,
     typename = std::enable_if_t<llvm::is_one_of<
@@ -289,6 +250,10 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
   LogicalResult
   matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    auto mixedOffsets = op.getMixedOffsets();
+    int64_t opOffsetsSize = mixedOffsets.size();
+    if (opOffsetsSize != 2)
+      return rewriter.notifyMatchFailure(op, "Expected 2D offsets.");
     auto loc = op.getLoc();
     auto ctxt = rewriter.getContext();
 
@@ -311,32 +276,16 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
         rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW));
     Value baseShapeH = vector::ExtractOp::create(
         rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH));
-    // Offsets provided in two ways:
-    // 1. Offsets are extracted from the tensor descriptor.
-    // 2. (Mixed) offsets which are provided by the op.
-    Value offsetW;
-    Value offsetH;
-    auto mixedOffsets = op.getMixedOffsets();
-    int64_t opOffsetsSize = mixedOffsets.size();
-    if (opOffsetsSize != 0 && opOffsetsSize != 2)
-      return rewriter.notifyMatchFailure(op,
-                                         "Expected 2D offsets or no offsets.");
-    if (opOffsetsSize) {
-      // If mixed offsets are provided by the op convert them to i32.
-      offsetW = getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]);
-      offsetW = getValueOrCreateCastToIndexLike(rewriter, loc,
-                                                rewriter.getI32Type(), offsetW);
-      offsetH = getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
-      offsetH = getValueOrCreateCastToIndexLike(rewriter, loc,
-                                                rewriter.getI32Type(), offsetH);
-    } else {
-      // If offsets are not available, we need to extract them from the tensor
-      // descriptor.
-      offsetW = vector::ExtractOp::create(
-          rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::TensorOffsetW));
-      offsetH = vector::ExtractOp::create(
-          rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::TensorOffsetH));
-    }
+    // Offsets are provided by the op.
+    // convert them to i32.
+    Value offsetW =
+        getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]);
+    offsetW = getValueOrCreateCastToIndexLike(rewriter, loc,
+                                              rewriter.getI32Type(), offsetW);
+    Value offsetH =
+        getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
+    offsetH = getValueOrCreateCastToIndexLike(rewriter, loc,
+                                              rewriter.getI32Type(), offsetH);
     // Get address space from tensor descriptor memory space.
     auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
         ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
@@ -422,54 +371,6 @@ static Value addOffset(ConversionPatternRewriter &rewriter, Location loc,
   return newAddr;
 }
 
-class CreateDescToXeVMPattern
-    : public OpConversionPattern<xegpu::CreateDescOp> {
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(xegpu::CreateDescOp op, xegpu::CreateDescOp::Adaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto eTy = op.getTensorDescType().getElementType();
-    auto eBw = eTy.getIntOrFloatBitWidth();
-    if (eBw % 8 != 0)
-      return rewriter.notifyMatchFailure(
-          op, "Expected element type bit width to be multiple of 8.");
-    auto loc = op.getLoc();
-    // Offsets are provided as scalar i64 by type converter.
-    auto offsets = adaptor.getOffsets();
-    // Source type can be a 1D memref or pointer type (ui64, ui32, i64 or i32).
-    // But type converter will convert them to integer types.
-    Value addr = adaptor.getSource();
-    // ui32 or i32 are passed as i32 so they need to be casted to i64.
-    if (addr.getType() != rewriter.getI64Type())
-      addr = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(), addr);
-    auto laneAddr = addOffset(rewriter, loc, addr, offsets, eBw / 8);
-    rewriter.replaceOp(op, laneAddr);
-    return success();
-  }
-};
-
-class UpdateOffsetToXeVMPattern
-    : public OpConversionPattern<xegpu::UpdateOffsetOp> {
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(xegpu::UpdateOffsetOp op,
-                  xegpu::UpdateOffsetOp::Adaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto eTy = op.getTensorDescType().getElementType();
-    auto eBw = eTy.getIntOrFloatBitWidth();
-    if (eBw % 8 != 0)
-      return rewriter.notifyMatchFailure(
-          op, "Expected element type bit width to be multiple of 8.");
-    auto loc = op.getLoc();
-    // Scatter descriptor is provided as scalar i64 by type converter.
-    // Offsets are provided as scalar i64 by type converter.
-    Value newOffset = addOffset(rewriter, loc, adaptor.getTensorDesc(),
-                                adaptor.getOffsets(), eBw / 8);
-    rewriter.replaceOp(op, newOffset);
-    return success();
-  }
-};
-
 template <typename OpType,
           typename = std::enable_if_t<llvm::is_one_of<
               OpType, xegpu::LoadGatherOp, xegpu::StoreScatterOp>::value>>
@@ -478,6 +379,9 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
   LogicalResult
   matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    Value offset = adaptor.getOffsets();
+    if (!offset)
+      return rewriter.notifyMatchFailure(op, "Expected offset to be provided.");
     auto loc = op.getLoc();
     auto ctxt = rewriter.getContext();
     auto tdescTy = op.getTensorDescType();
@@ -527,21 +431,16 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
       basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(),
                                           basePtrI64);
     }
-    Value offsets = adaptor.getOffsets();
     Value mask = adaptor.getMask();
-    if (offsets) {
-      if (dyn_cast<VectorType>(offsets.getType())) {
-        // Offset needs be scalar. Single element vector is converted to scalar
-        // by type converter.
-        return rewriter.notifyMatchFailure(op,
-                                           "Expected offsets to be a scalar.");
-      } else {
-        // If offsets are provided, we add them to the base pointer.
-        // Offsets are in number of elements, we need to multiply by
-        // element byte size.
-        basePtrI64 =
-            addOffset(rewriter, loc, basePtrI64, offsets, elemByteSize);
-      }
+    if (dyn_cast<VectorType>(offset.getType())) {
+      // Offset needs be scalar. Single element vector is converted to scalar
+      // by type converter.
+      return rewriter.notifyMatchFailure(op, "Expected offset to be a scalar.");
+    } else {
+      // If offset is provided, we add them to the base pointer.
+      // Offset is in number of elements, we need to multiply by
+      // element byte size.
+      basePtrI64 = addOffset(rewriter, loc, basePtrI64, offset, elemByteSize);
     }
     // Convert base pointer (i64) to LLVM pointer type.
     Value basePtrLLVM =
@@ -1011,13 +910,12 @@ struct ConvertXeGPUToXeVMPass
 //===----------------------------------------------------------------------===//
 void mlir::populateXeGPUToXeVMConversionPatterns(
     const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
-  patterns.add<CreateNdDescToXeVMPattern, UpdateNdOffsetToXeVMPattern,
+  patterns.add<CreateNdDescToXeVMPattern,
               LoadStorePrefetchNdToXeVMPattern<xegpu::LoadNdOp>,
               LoadStorePrefetchNdToXeVMPattern<xegpu::StoreNdOp>,
               LoadStorePrefetchNdToXeVMPattern<xegpu::PrefetchNdOp>>(
      typeConverter, patterns.getContext());
-  patterns.add<CreateDescToXeVMPattern, UpdateOffsetToXeVMPattern,
-               AtomicRMWToXeVMPattern, PrefetchToXeVMPattern,
+  patterns.add<AtomicRMWToXeVMPattern, PrefetchToXeVMPattern,
               LoadStoreToXeVMPattern<xegpu::LoadGatherOp>,
               LoadStoreToXeVMPattern<xegpu::StoreScatterOp>>(
      typeConverter, patterns.getContext());

mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir

Lines changed: 0 additions & 32 deletions
@@ -43,38 +43,6 @@ gpu.module @create_nd_tdesc
     // CHECK: %[[VAR19:.*]] = vector.insert %[[OFFSET_W2]], %[[VAR18]] [4] : i32 into vector<8xi32>
     // CHECK: %[[PAYLOAD:.*]] = vector.insert %[[OFFSET_H2]], %[[VAR19]] [5] : i32 into vector<8xi32>
     %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
-
-    // CHECK: %[[CST_4:.*]] = arith.constant dense<0> : vector<8xi32>
-    // CHECK: %[[INTPTR_2:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<16x32xf32> -> index
-    // CHECK: %[[OFFSET_W3:.*]] = arith.index_cast %[[ARG7]] : index to i32
-    // CHECK: %[[OFFSET_H3:.*]] = arith.index_cast %[[ARG6]] : index to i32
-    // CHECK: %[[C32_I64_6:.*]] = arith.constant 32 : i64
-    // CHECK: %[[SHAPE_W3:.*]] = arith.trunci %[[C32_I64_6]] : i64 to i32
-    // CHECK: %[[C16_I64_7:.*]] = arith.constant 16 : i64
-    // CHECK: %[[SHAPE_H3:.*]] = arith.trunci %[[C16_I64_7]] : i64 to i32
-    // CHECK: %[[BASE_ADDR3:.*]] = arith.index_castui %[[INTPTR_2]] : index to i64
-    // CHECK: %[[VAR26:.*]] = vector.bitcast %[[CST_4]] : vector<8xi32> to vector<4xi64>
-    // CHECK: %[[VAR27:.*]] = vector.insert %[[BASE_ADDR3]], %[[VAR26]] [0] : i64 into vector<4xi64>
-    // CHECK: %[[VAR28:.*]] = vector.bitcast %[[VAR27]] : vector<4xi64> to vector<8xi32>
-    // CHECK: %[[VAR29:.*]] = vector.insert %[[SHAPE_W3]], %[[VAR28]] [2] : i32 into vector<8xi32>
-    // CHECK: %[[VAR30:.*]] = vector.insert %[[SHAPE_H3]], %[[VAR29]] [3] : i32 into vector<8xi32>
-    // CHECK: %[[VAR31:.*]] = vector.insert %[[OFFSET_W3]], %[[VAR30]] [4] : i32 into vector<8xi32>
-    // CHECK: %[[VAR32:.*]] = vector.insert %[[OFFSET_H3]], %[[VAR31]] [5] : i32 into vector<8xi32>
-    %src_tdesc2 = xegpu.create_nd_tdesc %srcce[%offset1, %offset2] : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
-
-    // CHECK: %[[C8:.*]] = arith.constant 8 : index
-    %c8 = arith.constant 8 : index
-    // CHECK: %[[C16:.*]] = arith.constant 16 : index
-    %c16 = arith.constant 16 : index
-    // CHECK: %[[VAR33:.*]] = arith.index_cast %[[C8]] : index to i32
-    // CHECK: %[[OLD_OFFSET_H:.*]] = vector.extract %[[PAYLOAD]][5] : i32 from vector<8xi32>
-    // CHECK: %[[NEW_OFFSET_H:.*]] = arith.addi %[[OLD_OFFSET_H]], %[[VAR33]] : i32
-    // CHECK: %[[NEW_PAYLOAD:.*]] = vector.insert %[[NEW_OFFSET_H]], %[[PAYLOAD]] [5] : i32 into vector<8xi32>
-    // CHECK: %[[VAR37:.*]] = arith.index_cast %[[C16]] : index to i32
-    // CHECK: %[[OLD_OFFSET_H:.*]] = vector.extract %[[NEW_PAYLOAD]][4] : i32 from vector<8xi32>
-    // CHECK: %[[NEW_OFFSET_H:.*]] = arith.addi %[[OLD_OFFSET_H]], %[[VAR37]] : i32
-    // CHECK: %[[FINAL_PAYLOAD:.*]] = vector.insert %[[NEW_OFFSET_H]], %[[NEW_PAYLOAD]] [4] : i32 into vector<8xi32>
-    %updated_tdesc = xegpu.update_nd_offset %src_tdesc, [%c8, %c16] : !xegpu.tensor_desc<8x16xf32>
     gpu.return
   }
 }
