@@ -66,6 +66,18 @@ static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
   llvm_unreachable("Unknown XeGPU memory space");
 }
 
+/// Checks if the given MemRefType refers to shared memory.
+static bool isSharedMemRef(const MemRefType &memrefTy) {
+  Attribute attr = memrefTy.getMemorySpace();
+  if (!attr)
+    return false;
+  if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr))
+    return intAttr.getInt() == static_cast<int>(xevm::AddrSpace::SHARED);
+  if (auto xevmSpace = llvm::dyn_cast<xevm::AddrSpaceAttr>(attr))
+    return xevmSpace.getValue() == xevm::AddrSpace::SHARED;
+  return gpu::GPUDialect::isWorkgroupMemoryAddressSpace(attr);
+}
+
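+// For illustration (a sketch; the attribute syntax is approximate, not
+// verified output), memory spaces the helper above treats as SLM include:
+//   memref<8x16xf16, 3>                              (IntegerAttr matching
+//                                                     xevm::AddrSpace::SHARED)
+//   memref<8x16xf16, #gpu.address_space<workgroup>>  (GPU workgroup space)
+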
 // Get same bitwidth flat vector type of new element type.
 static VectorType encodeVectorTypeTo(VectorType currentVecType,
                                      Type toElemType) {
@@ -1066,27 +1078,69 @@ struct ConvertXeGPUToXeVMPass
     });
 
     typeConverter.addConversion([&](MemRefType type) -> Type {
-      if (type.getMemorySpaceAsInt() == 3)
-        return IntegerType::get(&getContext(), 32);
-      return IntegerType::get(&getContext(), 64);
+      return IntegerType::get(&getContext(), (isSharedMemRef(type) ? 32 : 64));
     });
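+    // E.g. (illustrative): memref<8x16xf16, 3> (SLM) converts to i32, while
+    // memref<8x16xf16> in the default (global) space converts to i64.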
 
     // LLVM type converter puts unrealized casts for the following cases:
     // add materialization casts to handle them.
 
-    // Materialization to convert memref to i64
+    // Materialization to convert a memref to i64 (global) or i32 (SLM).
     auto memrefMaterializationCast = [](OpBuilder &builder, Type type,
                                         ValueRange inputs,
                                         Location loc) -> Value {
       if (inputs.size() != 1)
         return {};
       auto input = inputs.front();
       if (auto memrefTy = dyn_cast<MemRefType>(input.getType())) {
+        unsigned rank = memrefTy.getRank();
+        Type indexType = builder.getIndexType();
 
-        Value addr =
-            memref::ExtractAlignedPointerAsIndexOp::create(builder, loc, input);
-        return arith::IndexCastUIOp::create(builder, loc, type, addr)
-            .getResult();
+        int64_t intOffsets;
+        SmallVector<int64_t> intStrides;
+        Value addr;
+        Value offset;
+        if (succeeded(memrefTy.getStridesAndOffset(intStrides, intOffsets)) &&
+            ShapedType::isStatic(intOffsets)) {
+          addr = memref::ExtractAlignedPointerAsIndexOp::create(builder, loc,
+                                                                input);
+          offset = arith::ConstantOp::create(builder, loc,
+                                             builder.getIndexAttr(intOffsets));
+        } else {
+          // Offset is not statically known; extract the base pointer and
+          // offset at runtime via memref.extract_strided_metadata.
+          // Result types: [base_memref, offset, size0, size1, ..., sizeN-1,
+          //                stride0, stride1, ..., strideN-1]
+          SmallVector<Type> resultTypes{
+              MemRefType::get({}, memrefTy.getElementType(),
+                              MemRefLayoutAttrInterface(),
+                              memrefTy.getMemorySpace()),
+              indexType};
+          // sizes + strides
+          resultTypes.append(2 * rank, indexType);
+
+          auto meta = memref::ExtractStridedMetadataOp::create(
+              builder, loc, resultTypes, input);
+
+          addr = memref::ExtractAlignedPointerAsIndexOp::create(
+              builder, loc, meta.getBaseBuffer());
+          offset = meta.getOffset();
+        }
+
+        auto addrCasted =
+            arith::IndexCastUIOp::create(builder, loc, type, addr);
+        auto offsetCasted =
+            arith::IndexCastUIOp::create(builder, loc, type, offset);
+
+        // Compute the final address: base address + byte offset.
+        auto byteSize = arith::ConstantOp::create(
+            builder, loc, type,
+            builder.getIntegerAttr(type,
+                                   memrefTy.getElementTypeBitWidth() / 8));
+        auto byteOffset =
+            arith::MulIOp::create(builder, loc, offsetCasted, byteSize);
+        auto addrWithOffset =
+            arith::AddIOp::create(builder, loc, addrCasted, byteOffset);
+
+        return addrWithOffset.getResult();
       }
       return {};
     };
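+
+    // For illustration (a sketch, not verbatim compiler output): for a
+    // statically strided global f16 memref with offset 5, the cast above
+    // materializes roughly:
+    //   %base   = memref.extract_aligned_pointer_as_index %src
+    //   %off    = arith.constant 5 : index
+    //   %base_i = arith.index_castui %base : index to i64
+    //   %off_i  = arith.index_castui %off : index to i64
+    //   %c2     = arith.constant 2 : i64        // f16 = 2 bytes
+    //   %bytes  = arith.muli %off_i, %c2 : i64  // 5 * 2 = 10 bytes
+    //   %addr   = arith.addi %base_i, %bytes : i64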