Skip to content

Commit 2664dc1

Browse files
committed
add static offset support
1 parent 43d9ddb commit 2664dc1

File tree

6 files changed

+146
-121
lines changed

6 files changed

+146
-121
lines changed

mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,9 @@ template <typename T>
175175
int getLargestDivisor(T dim, ArrayRef<T> candidates,
176176
ArrayRef<T> candidateMultiples = {});
177177

178+
/// Checks if the given MemRefType refers to shared memory.
179+
bool isSharedMemRef(const MemRefType &memrefTy);
180+
178181
} // namespace xegpu
179182

180183
} // namespace mlir

mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp

Lines changed: 31 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -991,9 +991,8 @@ struct ConvertXeGPUToXeVMPass
991991
});
992992

993993
typeConverter.addConversion([&](MemRefType type) -> Type {
994-
if (type.getMemorySpaceAsInt() == 3)
995-
return IntegerType::get(&getContext(), 32);
996-
return IntegerType::get(&getContext(), 64);
994+
return IntegerType::get(&getContext(),
995+
(xegpu::isSharedMemRef(type) ? 32 : 64));
997996
});
998997

999998
// LLVM type converter puts unrealized casts for the following cases:
@@ -1010,24 +1009,35 @@ struct ConvertXeGPUToXeVMPass
10101009
unsigned rank = memrefTy.getRank();
10111010
Type indexType = builder.getIndexType();
10121011

1013-
SmallVector<Type> resultTypes;
1014-
// Result types: [base_memref, offset, stride0, stride1, ..., strideN-1,
1015-
// size0, size1, ..., sizeN-1]
1016-
resultTypes.push_back(MemRefType::get(
1017-
{}, memrefTy.getElementType(), MemRefLayoutAttrInterface(),
1018-
memrefTy.getMemorySpace())); // base memref (unranked)
1019-
resultTypes.push_back(indexType); // offset
1020-
for (unsigned i = 0; i < rank; ++i)
1021-
resultTypes.push_back(indexType); // strides
1022-
for (unsigned i = 0; i < rank; ++i)
1023-
resultTypes.push_back(indexType); // sizes
1024-
1025-
auto meta = memref::ExtractStridedMetadataOp::create(
1026-
builder, loc, resultTypes, input);
1027-
1028-
auto addr = memref::ExtractAlignedPointerAsIndexOp::create(
1029-
builder, loc, meta.getBaseBuffer());
1030-
auto offset = meta.getOffset();
1012+
int64_t intOffsets;
1013+
SmallVector<int64_t> intStrides;
1014+
Value addr;
1015+
Value offset;
1016+
if (failed(memrefTy.getStridesAndOffset(intStrides, intOffsets))) {
1017+
1018+
// Result types: [base_memref, offset, stride0, stride1, ...,
1019+
// strideN-1, size0, size1, ..., sizeN-1]
1020+
SmallVector<Type> resultTypes{
1021+
MemRefType::get({}, memrefTy.getElementType(),
1022+
MemRefLayoutAttrInterface(),
1023+
memrefTy.getMemorySpace()),
1024+
indexType};
1025+
// strides + sizes
1026+
resultTypes.append(2 * rank, indexType);
1027+
1028+
auto meta = memref::ExtractStridedMetadataOp::create(
1029+
builder, loc, resultTypes, input);
1030+
1031+
addr = memref::ExtractAlignedPointerAsIndexOp::create(
1032+
builder, loc, meta.getBaseBuffer());
1033+
offset = meta.getOffset();
1034+
1035+
} else {
1036+
addr = memref::ExtractAlignedPointerAsIndexOp::create(builder, loc,
1037+
input);
1038+
offset = arith::ConstantOp::create(builder, loc,
1039+
builder.getIndexAttr(intOffsets));
1040+
}
10311041

10321042
auto addr_casted =
10331043
arith::IndexCastUIOp::create(builder, loc, type, addr);

mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -580,3 +580,15 @@ template int xegpu::getLargestDivisor<int>(int dim, ArrayRef<int> candidates,
580580
template int
581581
xegpu::getLargestDivisor<unsigned>(unsigned dim, ArrayRef<unsigned> candidates,
582582
ArrayRef<unsigned> candidateMultiples);
583+
584+
/// Checks if the given MemRefType refers to shared memory.
585+
bool xegpu::isSharedMemRef(const MemRefType &memrefTy) {
586+
Attribute attr = memrefTy.getMemorySpace();
587+
if (!attr)
588+
return false;
589+
if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr))
590+
return intAttr.getInt() == static_cast<int>(xevm::AddrSpace::SHARED);
591+
if (auto xevmSpace = llvm::dyn_cast<xevm::AddrSpaceAttr>(attr))
592+
return xevmSpace.getValue() == xevm::AddrSpace::SHARED;
593+
return gpu::GPUDialect::isWorkgroupMemoryAddressSpace(attr);
594+
}

mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@ gpu.module @create_nd_tdesc {
88
gpu.func @create_nd_tdesc(%src: memref<16x32xf32, 1>, %ptr: ui64, %shape1: index, %shape2: index,
99
%stride1: index, %stride2: index, %offset1: index, %offset2: index, %dyn: memref<?x?xf16>) kernel {
1010
// CHECK: %[[INTPTR_5:.*]] = memref.extract_aligned_pointer_as_index %[[DYN]] : memref<?x?xf16> -> index
11-
// CHECK: %[[VAR23:.*]] = arith.index_castui %[[INTPTR_5]] : index to i64
11+
// CHECK: %[[BASE_ADDR3:.*]] = arith.index_castui %[[INTPTR_5]] : index to i64
12+
1213
// CHECK: %[[VAR0:.*]] = index.castu %[[ARG1]] : ui64 to index
1314
// CHECK: %[[BASE_ADDR:.*]] = arith.index_castui %[[VAR0]] : index to i64
1415
// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32>
@@ -39,7 +40,7 @@ gpu.module @create_nd_tdesc {
3940
// CHECK: %[[C16_I64:.*]] = arith.constant 16 : i64
4041
// CHECK: %[[SHAPE_H2:.*]] = arith.trunci %[[C16_I64]] : i64 to i32
4142
// CHECK: %[[VAR14:.*]] = vector.bitcast %[[CST_1]] : vector<8xi32> to vector<4xi64>
42-
// CHECK: %[[VAR15:.*]] = vector.insert %[[BASE_ADDR2]], %[[VAR14]] [0] : i64 into vector<4xi64>
43+
// CHECK: %[[VAR15:.*]] = vector.insert %[[BASE_ADDR2_OFFSET:.*]], %[[VAR14]] [0] : i64 into vector<4xi64>
4344
// CHECK: %[[VAR16:.*]] = vector.bitcast %[[VAR15]] : vector<4xi64> to vector<8xi32>
4445
// CHECK: %[[VAR17:.*]] = vector.insert %[[SHAPE_W2]], %[[VAR16]] [2] : i32 into vector<8xi32>
4546
// CHECK: %[[VAR18:.*]] = vector.insert %[[SHAPE_H2]], %[[VAR17]] [3] : i32 into vector<8xi32>
@@ -53,13 +54,14 @@ gpu.module @create_nd_tdesc {
5354
%size_x = arith.constant 64 : index
5455
// CHECK: %[[C16:.*]] = arith.constant 16 : index
5556
%BLOCK_DMODEL = arith.constant 16 : index
57+
5658
// CHECK: %[[CST_4:.*]] = arith.constant dense<0> : vector<8xi32>
5759
// CHECK: %[[C0_I32_6:.*]] = arith.constant 0 : i32
5860
// CHECK: %[[C0_I32_7:.*]] = arith.constant 0 : i32
5961
// CHECK: %[[VAR21:.*]] = arith.index_cast %[[C16]] : index to i32
6062
// CHECK: %[[VAR22:.*]] = arith.index_cast %[[C64]] : index to i32
6163
// CHECK: %[[VAR24:.*]] = vector.bitcast %[[CST_4]] : vector<8xi32> to vector<4xi64>
62-
// CHECK: %[[VAR25:.*]] = vector.insert %[[VAR23]], %[[VAR24]] [0] : i64 into vector<4xi64>
64+
// CHECK: %[[VAR25:.*]] = vector.insert %[[BASE_ADDR3_OFFSET:.*]], %[[VAR24]] [0] : i64 into vector<4xi64>
6365
// CHECK: %[[VAR26:.*]] = vector.bitcast %[[VAR25]] : vector<4xi64> to vector<8xi32>
6466
// CHECK: %[[VAR27:.*]] = vector.insert %[[VAR21]], %[[VAR26]] [2] : i32 into vector<8xi32>
6567
// CHECK: %[[VAR28:.*]] = vector.insert %[[VAR22]], %[[VAR27]] [3] : i32 into vector<8xi32>

0 commit comments

Comments
 (0)