3 changes: 3 additions & 0 deletions mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -175,6 +175,9 @@ template <typename T>
int getLargestDivisor(T dim, ArrayRef<T> candidates,
ArrayRef<T> candidateMultiples = {});

/// Checks if the given MemRefType refers to shared memory.
bool isSharedMemRef(const MemRefType &memrefTy);
Contributor: nit: This is only used in a single file for now (XeGPUToXeVM). For now you can have this as a helper inside that file. If more uses arise we can have a common place for it.


} // namespace xegpu

} // namespace mlir
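A minimal sketch of that nit, assuming the helper is kept file-local in XeGPUToXeVM.cpp instead of being exported from XeGPUUtils.h; the body simply mirrors the xegpu::isSharedMemRef utility added in XeGPUUtils.cpp further down, so only the placement is the assumption here.

// Hypothetical file-local variant of the helper, per the reviewer's nit.
// The body mirrors xegpu::isSharedMemRef from XeGPUUtils.cpp in this PR.
static bool isSharedMemRef(MemRefType memrefTy) {
  Attribute attr = memrefTy.getMemorySpace();
  if (!attr)
    return false;
  // Integer address space: compare against the XeVM shared address space.
  if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr))
    return intAttr.getInt() == static_cast<int>(xevm::AddrSpace::SHARED);
  // XeVM address-space attribute.
  if (auto xevmSpace = llvm::dyn_cast<xevm::AddrSpaceAttr>(attr))
    return xevmSpace.getValue() == xevm::AddrSpace::SHARED;
  // Fall back to the GPU dialect's notion of workgroup (shared) memory.
  return gpu::GPUDialect::isWorkgroupMemoryAddressSpace(attr);
}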
59 changes: 51 additions & 8 deletions mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -991,27 +991,70 @@ struct ConvertXeGPUToXeVMPass
});

typeConverter.addConversion([&](MemRefType type) -> Type {
if (type.getMemorySpaceAsInt() == 3)
return IntegerType::get(&getContext(), 32);
return IntegerType::get(&getContext(), 64);
return IntegerType::get(&getContext(),
(xegpu::isSharedMemRef(type) ? 32 : 64));
});

// LLVM type converter puts unrealized casts for the following cases:
// add materialization casts to handle them.

// Materialization to convert memref to i64
// Materialization to convert memref to i64 or i32 depending on global/SLM
auto memrefMaterializationCast = [](OpBuilder &builder, Type type,
ValueRange inputs,
Location loc) -> Value {
if (inputs.size() != 1)
return {};
auto input = inputs.front();
if (auto memrefTy = dyn_cast<MemRefType>(input.getType())) {
unsigned rank = memrefTy.getRank();
Type indexType = builder.getIndexType();

Value addr =
memref::ExtractAlignedPointerAsIndexOp::create(builder, loc, input);
return arith::IndexCastUIOp::create(builder, loc, type, addr)
.getResult();
int64_t intOffsets;
SmallVector<int64_t> intStrides;
Value addr;
Value offset;
if (failed(memrefTy.getStridesAndOffset(intStrides, intOffsets))) {
Contributor: why this check? Why not use ExtractStridedMetadataOp for all cases? The case above could return ShapedType::kDynamic for dynamic values. (A sketch of this alternative follows the lambda below.)


// Result types: [base_memref, offset, stride0, stride1, ...,
// strideN-1, size0, size1, ..., sizeN-1]
SmallVector<Type> resultTypes{
MemRefType::get({}, memrefTy.getElementType(),
MemRefLayoutAttrInterface(),
memrefTy.getMemorySpace()),
indexType};
// strides + sizes
resultTypes.append(2 * rank, indexType);

auto meta = memref::ExtractStridedMetadataOp::create(
builder, loc, resultTypes, input);

addr = memref::ExtractAlignedPointerAsIndexOp::create(
builder, loc, meta.getBaseBuffer());
offset = meta.getOffset();

} else {
addr = memref::ExtractAlignedPointerAsIndexOp::create(builder, loc,
input);
offset = arith::ConstantOp::create(builder, loc,
builder.getIndexAttr(intOffsets));
Contributor: this could be kDynamic, which is a special value, so this code is not correct.

}

auto addr_casted =
Contributor: LLVM does not use snake_case variable naming; rename to addrCasted.
arith::IndexCastUIOp::create(builder, loc, type, addr);
auto offset_casted =
arith::IndexCastUIOp::create(builder, loc, type, offset);

// Compute the final address: base address + byte offset
auto byte_size = arith::ConstantOp::create(
builder, loc, type,
builder.getIntegerAttr(type,
memrefTy.getElementTypeBitWidth() / 8));
auto byte_offset =
arith::MulIOp::create(builder, loc, offset_casted, byte_size);
auto addr_with_offset =
arith::AddIOp::create(builder, loc, addr_casted, byte_offset);

return addr_with_offset.getResult();
}
return {};
};
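Both review comments above point at the same simplification: go through memref.extract_strided_metadata unconditionally, so the offset is always materialized as an SSA index value and the static-vs-dynamic distinction (and the ShapedType::kDynamic sentinel) never has to be handled in the lambda. A hedged sketch under that assumption, reusing only pieces already present in this diff; variable names and the exact folding behaviour are assumptions, not the final implementation.

auto memrefMaterializationCast = [](OpBuilder &builder, Type type,
                                    ValueRange inputs,
                                    Location loc) -> Value {
  if (inputs.size() != 1)
    return {};
  Value input = inputs.front();
  auto memrefTy = dyn_cast<MemRefType>(input.getType());
  if (!memrefTy)
    return {};

  unsigned rank = memrefTy.getRank();
  Type indexType = builder.getIndexType();
  // Result types: [base_memref, offset, stride0, ..., strideN-1,
  //                size0, ..., sizeN-1]
  SmallVector<Type> resultTypes{
      MemRefType::get({}, memrefTy.getElementType(),
                      MemRefLayoutAttrInterface(), memrefTy.getMemorySpace()),
      indexType};
  resultTypes.append(2 * rank, indexType);

  // Always read base pointer and offset back from IR; a dynamic offset is
  // just another SSA value here, so kDynamic never appears.
  auto meta = memref::ExtractStridedMetadataOp::create(builder, loc,
                                                       resultTypes, input);
  Value addr = memref::ExtractAlignedPointerAsIndexOp::create(
      builder, loc, meta.getBaseBuffer());
  Value offset = meta.getOffset();

  Value addrCasted = arith::IndexCastUIOp::create(builder, loc, type, addr);
  Value offsetCasted =
      arith::IndexCastUIOp::create(builder, loc, type, offset);

  // Final address = base pointer + element offset scaled to bytes.
  Value byteSize = arith::ConstantOp::create(
      builder, loc, type,
      builder.getIntegerAttr(type, memrefTy.getElementTypeBitWidth() / 8));
  Value byteOffset =
      arith::MulIOp::create(builder, loc, offsetCasted, byteSize);
  return arith::AddIOp::create(builder, loc, addrCasted, byteOffset)
      .getResult();
};

For a statically known layout the extract_strided_metadata results are expected to fold back to constants, so the extra op should not normally survive into the final IR.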
12 changes: 12 additions & 0 deletions mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -579,3 +579,15 @@ template int xegpu::getLargestDivisor<int>(int dim, ArrayRef<int> candidates,
template int
xegpu::getLargestDivisor<unsigned>(unsigned dim, ArrayRef<unsigned> candidates,
ArrayRef<unsigned> candidateMultiples);

/// Checks if the given MemRefType refers to shared memory.
bool xegpu::isSharedMemRef(const MemRefType &memrefTy) {
Attribute attr = memrefTy.getMemorySpace();
if (!attr)
return false;
if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr))
return intAttr.getInt() == static_cast<int>(xevm::AddrSpace::SHARED);
if (auto xevmSpace = llvm::dyn_cast<xevm::AddrSpaceAttr>(attr))
return xevmSpace.getValue() == xevm::AddrSpace::SHARED;
return gpu::GPUDialect::isWorkgroupMemoryAddressSpace(attr);
}
8 changes: 5 additions & 3 deletions mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
@@ -8,7 +8,8 @@ gpu.module @create_nd_tdesc {
gpu.func @create_nd_tdesc(%src: memref<16x32xf32, 1>, %ptr: ui64, %shape1: index, %shape2: index,
%stride1: index, %stride2: index, %offset1: index, %offset2: index, %dyn: memref<?x?xf16>) kernel {
// CHECK: %[[INTPTR_5:.*]] = memref.extract_aligned_pointer_as_index %[[DYN]] : memref<?x?xf16> -> index
// CHECK: %[[VAR23:.*]] = arith.index_castui %[[INTPTR_5]] : index to i64
// CHECK: %[[BASE_ADDR3:.*]] = arith.index_castui %[[INTPTR_5]] : index to i64

// CHECK: %[[VAR0:.*]] = index.castu %[[ARG1]] : ui64 to index
// CHECK: %[[BASE_ADDR:.*]] = arith.index_castui %[[VAR0]] : index to i64
// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32>
@@ -39,7 +40,7 @@ gpu.module @create_nd_tdesc {
// CHECK: %[[C16_I64:.*]] = arith.constant 16 : i64
// CHECK: %[[SHAPE_H2:.*]] = arith.trunci %[[C16_I64]] : i64 to i32
// CHECK: %[[VAR14:.*]] = vector.bitcast %[[CST_1]] : vector<8xi32> to vector<4xi64>
// CHECK: %[[VAR15:.*]] = vector.insert %[[BASE_ADDR2]], %[[VAR14]] [0] : i64 into vector<4xi64>
// CHECK: %[[VAR15:.*]] = vector.insert %[[BASE_ADDR2_OFFSET:.*]], %[[VAR14]] [0] : i64 into vector<4xi64>
// CHECK: %[[VAR16:.*]] = vector.bitcast %[[VAR15]] : vector<4xi64> to vector<8xi32>
// CHECK: %[[VAR17:.*]] = vector.insert %[[SHAPE_W2]], %[[VAR16]] [2] : i32 into vector<8xi32>
// CHECK: %[[VAR18:.*]] = vector.insert %[[SHAPE_H2]], %[[VAR17]] [3] : i32 into vector<8xi32>
@@ -53,13 +54,14 @@ gpu.module @create_nd_tdesc {
%size_x = arith.constant 64 : index
// CHECK: %[[C16:.*]] = arith.constant 16 : index
%BLOCK_DMODEL = arith.constant 16 : index

// CHECK: %[[CST_4:.*]] = arith.constant dense<0> : vector<8xi32>
// CHECK: %[[C0_I32_6:.*]] = arith.constant 0 : i32
// CHECK: %[[C0_I32_7:.*]] = arith.constant 0 : i32
// CHECK: %[[VAR21:.*]] = arith.index_cast %[[C16]] : index to i32
// CHECK: %[[VAR22:.*]] = arith.index_cast %[[C64]] : index to i32
// CHECK: %[[VAR24:.*]] = vector.bitcast %[[CST_4]] : vector<8xi32> to vector<4xi64>
// CHECK: %[[VAR25:.*]] = vector.insert %[[VAR23]], %[[VAR24]] [0] : i64 into vector<4xi64>
// CHECK: %[[VAR25:.*]] = vector.insert %[[BASE_ADDR3_OFFSET:.*]], %[[VAR24]] [0] : i64 into vector<4xi64>
// CHECK: %[[VAR26:.*]] = vector.bitcast %[[VAR25]] : vector<4xi64> to vector<8xi32>
// CHECK: %[[VAR27:.*]] = vector.insert %[[VAR21]], %[[VAR26]] [2] : i32 into vector<8xi32>
// CHECK: %[[VAR28:.*]] = vector.insert %[[VAR22]], %[[VAR27]] [3] : i32 into vector<8xi32>