Commit 87e094d

[MLIR][Conversion] XeGPU to XeVM: Add handler for 1D block ops (llvm#165894)
Add lowering for xegpu.load_nd / xegpu.store_nd with a 1D tensor descriptor, and add a conversion test case.
1 parent fc093f1 commit 87e094d
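
For context, here is a minimal sketch of the kind of IR this patch now handles. The memref shape, descriptor size, and SSA names are illustrative assumptions, not taken from the commit's tests:

// Hypothetical input inside a gpu.func: a rank-1 block descriptor feeding
// offset-carrying load_nd / store_nd.
%td = xegpu.create_nd_tdesc %src : memref<256xf32> -> !xegpu.tensor_desc<16xf32>
%v  = xegpu.load_nd %td[%off] : !xegpu.tensor_desc<16xf32> -> vector<16xf32>
xegpu.store_nd %v, %td[%off] : vector<16xf32>, !xegpu.tensor_desc<16xf32>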

File tree

4 files changed, +205 -98 lines

mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp

Lines changed: 165 additions & 94 deletions
@@ -186,9 +186,6 @@ class CreateNdDescToXeVMPattern
     SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
     // Descriptor shape is expected to be 2D.
     int64_t rank = mixedSizes.size();
-    if (rank != 2)
-      return rewriter.notifyMatchFailure(op, "Expected 2D shape.");
-
     auto sourceTy = source.getType();
     auto sourceMemrefTy = dyn_cast<MemRefType>(sourceTy);
     // If source is a memref, we need to extract the aligned pointer as index.
@@ -199,8 +196,19 @@ class CreateNdDescToXeVMPattern
       }
       baseAddr =
           memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, source);
+      // Cast index to i64.
+      baseAddr = arith::IndexCastUIOp::create(rewriter, loc, i64Ty, baseAddr);
     } else {
       baseAddr = adaptor.getSource();
+      if (baseAddr.getType() != i64Ty) {
+        // Pointer type may be i32. Cast to i64 if needed.
+        baseAddr = arith::ExtUIOp::create(rewriter, loc, i64Ty, baseAddr);
+      }
+    }
+    // 1D tensor descriptor is just the base address.
+    if (rank == 1) {
+      rewriter.replaceOp(op, baseAddr);
+      return success();
     }
     // Utility for creating offset values from op fold result.
     auto createOffset = [&](SmallVector<OpFoldResult> &ofrVec,
@@ -215,13 +223,6 @@ class CreateNdDescToXeVMPattern
     // Get shape values from op fold results.
     baseShapeW = createOffset(mixedSizes, 1);
     baseShapeH = createOffset(mixedSizes, 0);
-    if (sourceMemrefTy) {
-      // Cast index to i64.
-      baseAddr = arith::IndexCastUIOp::create(rewriter, loc, i64Ty, baseAddr);
-    } else if (baseAddr.getType() != i64Ty) {
-      // Pointer type may be i32. Cast to i64 if needed.
-      baseAddr = arith::ExtUIOp::create(rewriter, loc, i64Ty, baseAddr);
-    }
     // Populate payload.
     Value payLoadAsI64 =
         vector::BitCastOp::create(rewriter, loc, payloadI64Ty, payload);
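
With the CreateNdDescToXeVMPattern changes above, a rank-1 descriptor no longer materializes the 8xi32 payload vector: the converted value is simply the base address as i64. A rough sketch of the expected 1D lowering for a memref source (SSA names and the memref shape are illustrative; the ops mirror the existing 2D CHECK lines in the test below):

%intptr = memref.extract_aligned_pointer_as_index %src : memref<256xf32> -> index
%base = arith.index_castui %intptr : index to i64
// %base alone replaces the converted !xegpu.tensor_desc value; no payload vector is built.
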
@@ -257,108 +258,175 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
                   ConversionPatternRewriter &rewriter) const override {
     auto mixedOffsets = op.getMixedOffsets();
     int64_t opOffsetsSize = mixedOffsets.size();
-    if (opOffsetsSize != 2)
-      return rewriter.notifyMatchFailure(op, "Expected 2D offsets.");
     auto loc = op.getLoc();
     auto ctxt = rewriter.getContext();
 
     auto tdesc = adaptor.getTensorDesc();
     auto tdescTy = op.getTensorDescType();
-    if (tdescTy.getRank() != 2)
-      return rewriter.notifyMatchFailure(op, "Expected 2D tensor descriptor.");
+    auto tileRank = tdescTy.getRank();
+    if (opOffsetsSize != tileRank)
+      return rewriter.notifyMatchFailure(
+          op, "Expected offset rank to match descriptor rank.");
     auto elemType = tdescTy.getElementType();
     auto elemBitSize = elemType.getIntOrFloatBitWidth();
     if (elemBitSize % 8 != 0)
       return rewriter.notifyMatchFailure(
           op, "Expected element type bit width to be multiple of 8.");
 
-    VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type());
-    Value payLoadAsI64 =
-        vector::BitCastOp::create(rewriter, loc, payloadI64Ty, tdesc);
-    Value basePtr = vector::ExtractOp::create(
-        rewriter, loc, payLoadAsI64, static_cast<int>(NdTdescOffset::BasePtr));
-    Value baseShapeW = vector::ExtractOp::create(
-        rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW));
-    Value baseShapeH = vector::ExtractOp::create(
-        rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH));
-    // Offsets are provided by the op.
-    // convert them to i32.
-    Value offsetW =
-        getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]);
-    offsetW = getValueOrCreateCastToIndexLike(rewriter, loc,
-                                              rewriter.getI32Type(), offsetW);
-    Value offsetH =
-        getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
-    offsetH = getValueOrCreateCastToIndexLike(rewriter, loc,
-                                              rewriter.getI32Type(), offsetH);
     // Get address space from tensor descriptor memory space.
     auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
         ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
-    // Convert base pointer (i64) to LLVM pointer type.
-    Value basePtrLLVM =
-        LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
-    // Compute element byte size and surface width in bytes.
-    Value elemByteSize = arith::ConstantIntOp::create(
-        rewriter, loc, rewriter.getI32Type(), elemBitSize / 8);
-    Value surfaceW =
-        arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);
-
-    // Get tile sizes and vblocks from the tensor descriptor type.
-    auto tileW = tdescTy.getDimSize(1);
-    auto tileH = tdescTy.getDimSize(0);
-    int32_t vblocks = tdescTy.getArrayLength();
-    if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
-      Value src = adaptor.getValue();
-      // If store value is a scalar, get value from op instead of adaptor.
-      // Adaptor might have optimized away single element vector
-      if (src.getType().isIntOrFloat()) {
-        src = op.getValue();
-      }
-      VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
-      if (!srcVecTy)
-        return rewriter.notifyMatchFailure(
-            op, "Expected store value to be a vector type.");
-      // Get flat vector type of integer type with matching element bit size.
-      VectorType newSrcVecTy =
-          encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
-      if (srcVecTy != newSrcVecTy)
-        src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
-      auto storeCacheControl =
-          translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
-      xevm::BlockStore2dOp::create(
-          rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW,
-          offsetH, elemBitSize, tileW, tileH, src,
-          xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
-      rewriter.eraseOp(op);
-    } else {
-      auto loadCacheControl =
-          translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
-      if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) {
-        xevm::BlockPrefetch2dOp::create(
+    if (tileRank == 2) {
+      // Compute element byte size.
+      Value elemByteSize = arith::ConstantIntOp::create(
+          rewriter, loc, rewriter.getI32Type(), elemBitSize / 8);
+      VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type());
+      Value payLoadAsI64 =
+          vector::BitCastOp::create(rewriter, loc, payloadI64Ty, tdesc);
+      Value basePtr =
+          vector::ExtractOp::create(rewriter, loc, payLoadAsI64,
+                                    static_cast<int>(NdTdescOffset::BasePtr));
+      Value baseShapeW = vector::ExtractOp::create(
+          rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW));
+      Value baseShapeH = vector::ExtractOp::create(
+          rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH));
+      // Offsets are provided by the op.
+      // convert them to i32.
+      Value offsetW =
+          getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]);
+      offsetW = getValueOrCreateCastToIndexLike(rewriter, loc,
+                                                rewriter.getI32Type(), offsetW);
+      Value offsetH =
+          getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
+      offsetH = getValueOrCreateCastToIndexLike(rewriter, loc,
+                                                rewriter.getI32Type(), offsetH);
+      // Convert base pointer (i64) to LLVM pointer type.
+      Value basePtrLLVM =
+          LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
+      // Compute width in bytes.
+      Value surfaceW =
+          arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);
+
+      // Get tile width from the tensor descriptor type.
+      auto tileW = tdescTy.getDimSize(tileRank - 1);
+      // Get tile height from the tensor descriptor type.
+      auto tileH = tdescTy.getDimSize(0);
+      // Get vblocks from the tensor descriptor type.
+      int32_t vblocks = tdescTy.getArrayLength();
+      if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
+        Value src = adaptor.getValue();
+        // If store value is a scalar, get value from op instead of adaptor.
+        // Adaptor might have optimized away single element vector
+        if (src.getType().isIntOrFloat()) {
+          src = op.getValue();
+        }
+        VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
+        if (!srcVecTy)
+          return rewriter.notifyMatchFailure(
+              op, "Expected store value to be a vector type.");
+        // Get flat vector type of integer type with matching element bit size.
+        VectorType newSrcVecTy =
+            encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
+        if (srcVecTy != newSrcVecTy)
+          src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
+        auto storeCacheControl =
+            translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
+        xevm::BlockStore2dOp::create(
            rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW,
-            offsetH, elemBitSize, tileW, tileH, vblocks,
-            xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
+            offsetH, elemBitSize, tileW, tileH, src,
+            xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
         rewriter.eraseOp(op);
       } else {
-        VectorType dstVecTy = cast<VectorType>(op.getValue().getType());
-        const bool vnni = op.getPacked().value_or(false);
-        auto transposeValue = op.getTranspose();
-        bool transpose =
-            transposeValue.has_value() && transposeValue.value()[0] == 1;
-        VectorType loadedTy = encodeVectorTypeTo(
-            dstVecTy, vnni ? rewriter.getI32Type()
-                           : rewriter.getIntegerType(elemBitSize));
-
-        Value resultFlatVec = xevm::BlockLoad2dOp::create(
-            rewriter, loc, loadedTy, basePtrLLVM, surfaceW, baseShapeH,
-            surfaceW, offsetW, offsetH, elemBitSize, tileW, tileH, vblocks,
-            transpose, vnni,
+        auto loadCacheControl =
+            translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
+        if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) {
+          xevm::BlockPrefetch2dOp::create(
+              rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW,
+              offsetW, offsetH, elemBitSize, tileW, tileH, vblocks,
+              xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
+          rewriter.eraseOp(op);
+        } else {
+          VectorType dstVecTy = cast<VectorType>(op.getValue().getType());
+          const bool vnni = op.getPacked().value_or(false);
+          auto transposeValue = op.getTranspose();
+          bool transpose =
+              transposeValue.has_value() && transposeValue.value()[0] == 1;
+          VectorType loadedTy = encodeVectorTypeTo(
+              dstVecTy, vnni ? rewriter.getI32Type()
+                             : rewriter.getIntegerType(elemBitSize));
+
+          Value resultFlatVec = xevm::BlockLoad2dOp::create(
+              rewriter, loc, loadedTy, basePtrLLVM, surfaceW, baseShapeH,
+              surfaceW, offsetW, offsetH, elemBitSize, tileW, tileH, vblocks,
+              transpose, vnni,
+              xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
+          resultFlatVec = vector::BitCastOp::create(
+              rewriter, loc,
+              encodeVectorTypeTo(loadedTy, dstVecTy.getElementType()),
+              resultFlatVec);
+          rewriter.replaceOp(op, resultFlatVec);
+        }
+      }
+    } else {
+      // 1D tensor descriptor.
+      // `tdesc` represents base address as i64
+      // Offset in number of elements, need to multiply by element byte size.
+      // Compute byte offset.
+      // byteOffset = offset * elementByteSize
+      Value offset =
+          getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
+      offset = getValueOrCreateCastToIndexLike(rewriter, loc,
+                                               rewriter.getI64Type(), offset);
+      // Compute element byte size.
+      Value elemByteSize = arith::ConstantIntOp::create(
+          rewriter, loc, rewriter.getI64Type(), elemBitSize / 8);
+      Value byteOffset =
+          rewriter.createOrFold<arith::MulIOp>(loc, offset, elemByteSize);
+      // Final address = basePtr + byteOffset
+      Value finalAddrI64 = rewriter.createOrFold<arith::AddIOp>(
+          loc, tdesc,
+          getValueOrCreateCastToIndexLike(rewriter, loc, rewriter.getI64Type(),
+                                          byteOffset));
+      // Convert base pointer (i64) to LLVM pointer type.
+      Value finalPtrLLVM =
+          LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, finalAddrI64);
+      if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
+        Value src = adaptor.getValue();
+        // If store value is a scalar, get value from op instead of adaptor.
+        // Adaptor might have optimized away single element vector
+        if (src.getType().isIntOrFloat()) {
+          src = op.getValue();
+        }
+        VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
+        if (!srcVecTy)
+          return rewriter.notifyMatchFailure(
+              op, "Expected store value to be a vector type.");
+        // Get flat vector type of integer type with matching element bit size.
+        VectorType newSrcVecTy =
+            encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
+        if (srcVecTy != newSrcVecTy)
+          src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
+        auto storeCacheControl =
+            translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
+        rewriter.replaceOpWithNewOp<xevm::BlockStoreOp>(
+            op, finalPtrLLVM, src,
+            xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
+      } else if constexpr (std::is_same_v<OpType, xegpu::LoadNdOp>) {
+        auto loadCacheControl =
+            translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
+        VectorType resTy = cast<VectorType>(op.getValue().getType());
+        VectorType loadedTy =
+            encodeVectorTypeTo(resTy, rewriter.getIntegerType(elemBitSize));
+        Value load = xevm::BlockLoadOp::create(
+            rewriter, loc, loadedTy, finalPtrLLVM,
            xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
-        resultFlatVec = vector::BitCastOp::create(
-            rewriter, loc,
-            encodeVectorTypeTo(loadedTy, dstVecTy.getElementType()),
-            resultFlatVec);
-        rewriter.replaceOp(op, resultFlatVec);
+        if (loadedTy != resTy)
+          load = vector::BitCastOp::create(rewriter, loc, resTy, load);
+        rewriter.replaceOp(op, load);
+      } else {
+        return rewriter.notifyMatchFailure(
+            op, "Unsupported operation: xegpu.prefetch_nd with tensor "
+                "descriptor rank == 1");
       }
     }
     return success();
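
In the rank-1 branch above, the offset carried by load_nd/store_nd is in elements, so it is scaled by the element byte width and added to the i64 base held by the converted descriptor before xevm::BlockLoadOp / xevm::BlockStoreOp is emitted. A hedged sketch of the emitted address arithmetic for 32-bit elements (SSA names and the address space are illustrative):

%c4 = arith.constant 4 : i64                      // elemBitSize / 8 for f32
%byte_off = arith.muli %offset, %c4 : i64         // offset already cast to i64
%addr = arith.addi %base, %byte_off : i64         // %base is the converted 1D descriptor
%ptr = llvm.inttoptr %addr : i64 to !llvm.ptr<1>  // consumed by the xevm block load/store
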
@@ -929,7 +997,10 @@ struct ConvertXeGPUToXeVMPass
       return VectorType::get(sum, elemType);
     });
     typeConverter.addConversion([&](xegpu::TensorDescType type) -> Type {
+      // Scattered descriptors are not supported in XeVM lowering.
       if (type.isScattered())
+        return {};
+      if (type.getRank() == 1)
         return IntegerType::get(&getContext(), 64);
       auto i32Type = IntegerType::get(&getContext(), 32);
       return VectorType::get(8, i32Type);
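
The type converter now distinguishes descriptor kinds: scattered descriptors get no conversion at all (so ops using them fail to legalize), rank-1 descriptors convert to a bare i64 address, and higher-rank descriptors keep the vector<8xi32> payload. As an illustrative mapping (element types are examples only):

!xegpu.tensor_desc<16xf32>   -> i64            // rank 1: bare base address
!xegpu.tensor_desc<8x16xf32> -> vector<8xi32>  // rank 2: payload vector
// scattered tensor_desc: no conversion registered; lowering rejects it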

mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir

Lines changed: 2 additions & 2 deletions
@@ -29,13 +29,13 @@ gpu.module @create_nd_tdesc {
 
     // CHECK: %[[CST_1:.*]] = arith.constant dense<0> : vector<8xi32>
     // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<16x32xf32> -> index
+    // CHECK: %[[BASE_ADDR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64
     // CHECK: %[[OFFSET_W2:.*]] = arith.constant 0 : i32
     // CHECK: %[[OFFSET_H2:.*]] = arith.constant 0 : i32
     // CHECK: %[[C32_I64:.*]] = arith.constant 32 : i64
     // CHECK: %[[SHAPE_W2:.*]] = arith.trunci %[[C32_I64]] : i64 to i32
     // CHECK: %[[C16_I64:.*]] = arith.constant 16 : i64
     // CHECK: %[[SHAPE_H2:.*]] = arith.trunci %[[C16_I64]] : i64 to i32
-    // CHECK: %[[BASE_ADDR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64
     // CHECK: %[[VAR14:.*]] = vector.bitcast %[[CST_1]] : vector<8xi32> to vector<4xi64>
     // CHECK: %[[VAR15:.*]] = vector.insert %[[BASE_ADDR2]], %[[VAR14]] [0] : i64 into vector<4xi64>
     // CHECK: %[[VAR16:.*]] = vector.bitcast %[[VAR15]] : vector<4xi64> to vector<8xi32>
@@ -53,11 +53,11 @@ gpu.module @create_nd_tdesc {
     %BLOCK_DMODEL = arith.constant 16 : index
     // CHECK: %[[CST_4:.*]] = arith.constant dense<0> : vector<8xi32>
     // CHECK: %[[INTPTR_5:.*]] = memref.extract_aligned_pointer_as_index %[[DYN]] : memref<?x?xf16> -> index
+    // CHECK: %[[VAR23:.*]] = arith.index_castui %[[INTPTR_5]] : index to i64
     // CHECK: %[[C0_I32_6:.*]] = arith.constant 0 : i32
     // CHECK: %[[C0_I32_7:.*]] = arith.constant 0 : i32
     // CHECK: %[[VAR21:.*]] = arith.index_cast %[[C16]] : index to i32
     // CHECK: %[[VAR22:.*]] = arith.index_cast %[[C64]] : index to i32
-    // CHECK: %[[VAR23:.*]] = arith.index_castui %[[INTPTR_5]] : index to i64
     // CHECK: %[[VAR24:.*]] = vector.bitcast %[[CST_4]] : vector<8xi32> to vector<4xi64>
     // CHECK: %[[VAR25:.*]] = vector.insert %[[VAR23]], %[[VAR24]] [0] : i64 into vector<4xi64>
     // CHECK: %[[VAR26:.*]] = vector.bitcast %[[VAR25]] : vector<4xi64> to vector<8xi32>
