Skip to content

Commit 8adcf0a

Browse files
authored
[MLIR][XeGPU] Support subview memref: handling the base address during xegpu to xevm type conversion (llvm#170541)
During the XeGPU-to-XeVM type conversion, a memref is lowered to its base address. This PR extends the conversion to correctly handle memrefs that include an offset, such as those generated by memref.subview.
1 parent 3a0c006 commit 8adcf0a

File tree

4 files changed

+220
-65
lines changed

4 files changed

+220
-65
lines changed

mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp

Lines changed: 62 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,18 @@ static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
6666
llvm_unreachable("Unknown XeGPU memory space");
6767
}
6868

69+
/// Checks if the given MemRefType refers to shared memory.
70+
static bool isSharedMemRef(const MemRefType &memrefTy) {
71+
Attribute attr = memrefTy.getMemorySpace();
72+
if (!attr)
73+
return false;
74+
if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr))
75+
return intAttr.getInt() == static_cast<int>(xevm::AddrSpace::SHARED);
76+
if (auto xevmSpace = llvm::dyn_cast<xevm::AddrSpaceAttr>(attr))
77+
return xevmSpace.getValue() == xevm::AddrSpace::SHARED;
78+
return gpu::GPUDialect::isWorkgroupMemoryAddressSpace(attr);
79+
}
80+
6981
// Get same bitwidth flat vector type of new element type.
7082
static VectorType encodeVectorTypeTo(VectorType currentVecType,
7183
Type toElemType) {
@@ -1066,27 +1078,69 @@ struct ConvertXeGPUToXeVMPass
10661078
});
10671079

10681080
typeConverter.addConversion([&](MemRefType type) -> Type {
1069-
if (type.getMemorySpaceAsInt() == 3)
1070-
return IntegerType::get(&getContext(), 32);
1071-
return IntegerType::get(&getContext(), 64);
1081+
return IntegerType::get(&getContext(), (isSharedMemRef(type) ? 32 : 64));
10721082
});
10731083

10741084
// LLVM type converter puts unrealized casts for the following cases:
10751085
// add materialization casts to handle them.
10761086

1077-
// Materialization to convert memref to i64
1087+
// Materialization to convert memref to i64 or i32 depending on global/SLM
10781088
auto memrefMaterializationCast = [](OpBuilder &builder, Type type,
10791089
ValueRange inputs,
10801090
Location loc) -> Value {
10811091
if (inputs.size() != 1)
10821092
return {};
10831093
auto input = inputs.front();
10841094
if (auto memrefTy = dyn_cast<MemRefType>(input.getType())) {
1095+
unsigned rank = memrefTy.getRank();
1096+
Type indexType = builder.getIndexType();
10851097

1086-
Value addr =
1087-
memref::ExtractAlignedPointerAsIndexOp::create(builder, loc, input);
1088-
return arith::IndexCastUIOp::create(builder, loc, type, addr)
1089-
.getResult();
1098+
int64_t intOffsets;
1099+
SmallVector<int64_t> intStrides;
1100+
Value addr;
1101+
Value offset;
1102+
if (succeeded(memrefTy.getStridesAndOffset(intStrides, intOffsets)) &&
1103+
ShapedType::isStatic(intOffsets)) {
1104+
addr = memref::ExtractAlignedPointerAsIndexOp::create(builder, loc,
1105+
input);
1106+
offset = arith::ConstantOp::create(builder, loc,
1107+
builder.getIndexAttr(intOffsets));
1108+
} else {
1109+
1110+
// Result types: [base_memref, offset, stride0, stride1, ...,
1111+
// strideN-1, size0, size1, ..., sizeN-1]
1112+
SmallVector<Type> resultTypes{
1113+
MemRefType::get({}, memrefTy.getElementType(),
1114+
MemRefLayoutAttrInterface(),
1115+
memrefTy.getMemorySpace()),
1116+
indexType};
1117+
// strides + sizes
1118+
resultTypes.append(2 * rank, indexType);
1119+
1120+
auto meta = memref::ExtractStridedMetadataOp::create(
1121+
builder, loc, resultTypes, input);
1122+
1123+
addr = memref::ExtractAlignedPointerAsIndexOp::create(
1124+
builder, loc, meta.getBaseBuffer());
1125+
offset = meta.getOffset();
1126+
}
1127+
1128+
auto addrCasted =
1129+
arith::IndexCastUIOp::create(builder, loc, type, addr);
1130+
auto offsetCasted =
1131+
arith::IndexCastUIOp::create(builder, loc, type, offset);
1132+
1133+
// Compute the final address: base address + byte offset
1134+
auto byteSize = arith::ConstantOp::create(
1135+
builder, loc, type,
1136+
builder.getIntegerAttr(type,
1137+
memrefTy.getElementTypeBitWidth() / 8));
1138+
auto byteOffset =
1139+
arith::MulIOp::create(builder, loc, offsetCasted, byteSize);
1140+
auto addrWithOffset =
1141+
arith::AddIOp::create(builder, loc, addrCasted, byteOffset);
1142+
1143+
return addrWithOffset.getResult();
10901144
}
10911145
return {};
10921146
};

mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ gpu.module @create_nd_tdesc {
3737
// CHECK: %[[C32_I64_2:.*]] = arith.constant 32 : i64
3838
// CHECK: %[[PITCH2:.*]] = arith.trunci %[[C32_I64_2]] : i64 to i32
3939
// CHECK: %[[VAR14:.*]] = vector.bitcast %[[CST_1]] : vector<8xi32> to vector<4xi64>
40-
// CHECK: %[[VAR15:.*]] = vector.insert %[[BASE_ADDR2]], %[[VAR14]] [0] : i64 into vector<4xi64>
40+
// CHECK: %[[VAR15:.*]] = vector.insert %[[BASE_ADDR2_OFFSET:.*]], %[[VAR14]] [0] : i64 into vector<4xi64>
4141
// CHECK: %[[VAR16:.*]] = vector.bitcast %[[VAR15]] : vector<4xi64> to vector<8xi32>
4242
// CHECK: %[[VAR17:.*]] = vector.insert %[[SHAPE_W2]], %[[VAR16]] [2] : i32 into vector<8xi32>
4343
// CHECK: %[[VAR18:.*]] = vector.insert %[[SHAPE_H2]], %[[VAR17]] [3] : i32 into vector<8xi32>
@@ -55,7 +55,7 @@ gpu.module @create_nd_tdesc {
5555
// CHECK: %[[SHAPE_H3:.*]] = arith.index_cast %[[C64]] : index to i32
5656
// CHECK: %[[PITCH3:.*]] = arith.index_cast %[[C16]] : index to i32
5757
// CHECK: %[[VAR25:.*]] = vector.bitcast %[[CST_3]] : vector<8xi32> to vector<4xi64>
58-
// CHECK: %[[VAR26:.*]] = vector.insert %[[DYN_ADDR]], %[[VAR25]] [0] : i64 into vector<4xi64>
58+
// CHECK: %[[VAR26:.*]] = vector.insert %[[DYN_ADDR_OFFSET:.*]], %[[VAR25]] [0] : i64 into vector<4xi64>
5959
// CHECK: %[[VAR27:.*]] = vector.bitcast %[[VAR26]] : vector<4xi64> to vector<8xi32>
6060
// CHECK: %[[VAR28:.*]] = vector.insert %[[SHAPE_W3]], %[[VAR27]] [2] : i32 into vector<8xi32>
6161
// CHECK: %[[VAR29:.*]] = vector.insert %[[SHAPE_H3]], %[[VAR28]] [3] : i32 into vector<8xi32>

0 commit comments

Comments
 (0)