11 changes: 3 additions & 8 deletions mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -1282,12 +1282,6 @@ def XeGPU_ConvertLayoutOp: XeGPU_Op<"convert_layout", [Pure, AllTypesMatch<["sou
let hasCanonicalizer = 1;
}

def isSharedPred : CPred<"isSharedMemory(llvm::cast<mlir::MemRefType>($_self))">;
class StaticShared1DMemRefOf<list<Type> allowedTypes> :
ConfinedType<MemRefRankOf<allowedTypes, [1]>, [HasStaticShapePred, isSharedPred],
"statically shaped " # MemRefOf<allowedTypes>.summary # " for shared memory",
"mlir::MemRefType">;

class SizeInBits<string name> :
StrFunc<"llvm::cast<mlir::ShapedType>($" # name # ".getType()).getNumElements()"
"*llvm::cast<mlir::ShapedType>($" # name # ".getType()).getElementTypeBitWidth()">;
@@ -1304,11 +1298,12 @@ def XeGPU_CreateMemDescOp: XeGPU_Op<"create_mem_desc", [Pure,
as the underlying shared local memory.

Arguments:
- `source` : a 1D statically shaped memref with element type i8, representing the raw SLM buffer.
- `source` : a 1D or 2D statically shaped memref, representing the raw SLM buffer.
The provided memref must be contiguous.
Results:
- `mem_desc` : the memory descriptor.
}];
let arguments = (ins StaticShared1DMemRefOf<[I8]>:$source);
let arguments = (ins AnyTypeOf<[StaticShared1DMemRefOf<[XeGPU_ScalarType]>, StaticShared2DMemRefOf<[XeGPU_ScalarType]>]>:$source);
let results = (outs XeGPU_MemDesc:$mem_desc);
let assemblyFormat = "$source prop-dict attr-dict `` `:` type($source) `->` qualified(type($mem_desc))";
}
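
A minimal usage sketch of the relaxed operand (mirroring the new tests in this patch; the allocations and shapes are illustrative):

%slm = memref.alloca() {alignment = 1024} : memref<4096xi8, 3>
%md0 = xegpu.create_mem_desc %slm : memref<4096xi8, 3> -> !xegpu.mem_desc<32x32xf32>
%buf = memref.alloca() {alignment = 1024} : memref<16x64xf16, 3>
%md1 = xegpu.create_mem_desc %buf : memref<16x64xf16, 3> -> !xegpu.mem_desc<16x64xf16>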
11 changes: 11 additions & 0 deletions mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -35,6 +35,17 @@ class XeGPUTypeDef<string name, string typeMnemonic, list<Trait> traits = [],
let mnemonic = typeMnemonic;
}

def isSharedPred : CPred<"isSharedMemory(llvm::cast<mlir::MemRefType>($_self))">;
class StaticShared1DMemRefOf<list<Type> allowedTypes> :
ConfinedType<MemRefRankOf<allowedTypes, [1]>, [HasStaticShapePred, isSharedPred],
"reside in share memory and statically 1d shaped " # MemRefOf<allowedTypes>.summary # " ",
"mlir::MemRefType">;

class StaticShared2DMemRefOf<list<Type> allowedTypes>:
ConfinedType<MemRefRankOf<allowedTypes, [2]>, [HasStaticShapePred, isSharedPred],
"reside in share memory and statically 2d shaped " # MemRefOf<allowedTypes>.summary # " ",
"mlir::MemRefType">;

def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
[ShapedTypeInterface], "::mlir::TensorType"> {
let summary = "TensorDesc describing regions of interested data.";
38 changes: 9 additions & 29 deletions mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -579,9 +579,6 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
}
};

// Lower xegpu::CreateMemDescOp to memref::ViewOp. Since SLM access instructions
// on Xe2 and Xe3 operate on 32-bit or 64-bit units, all data types smaller than
// 32 bits will be converted to 32 bits.
class CreateMemDescOpPattern final
: public OpConversionPattern<xegpu::CreateMemDescOp> {
public:
@@ -590,16 +587,7 @@ class CreateMemDescOpPattern final
matchAndRewrite(xegpu::CreateMemDescOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

auto resTy = op.getMemDesc();

// Create the result MemRefType with the same shape, element type, and
// memory space
auto newResTy = getTypeConverter()->convertType<MemRefType>(resTy);

Value zero = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0);
auto viewOp = memref::ViewOp::create(rewriter, op.getLoc(), newResTy,
op.getSource(), zero, ValueRange());
rewriter.replaceOp(op, viewOp);
rewriter.replaceOp(op, adaptor.getSource());
return success();
}
};
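
Because the type converter below maps both `!xegpu.mem_desc` and shared-local-memory memrefs to an i32 byte address, the pattern can now simply forward the converted source value. Roughly, as a sketch of what the updated tests check (the materialization ops are the ones appearing in those tests):

// Input:
//   %md = xegpu.create_mem_desc %src : memref<4096xi8, 3> -> !xegpu.mem_desc<32x32xf32>
// After conversion, uses of %md see the i32 SLM base address of %src:
%intptr = memref.extract_aligned_pointer_as_index %src : memref<4096xi8, 3> -> index
%base = arith.index_castui %intptr : index to i32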
@@ -619,7 +607,7 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {

auto loc = op.getLoc();
auto ctxt = rewriter.getContext();
Value basePtrStruct = adaptor.getMemDesc();
Value baseAddr32 = adaptor.getMemDesc();
Value mdescVal = op.getMemDesc();
// Load result or Store value Type can be vector or scalar.
Value data;
@@ -647,21 +635,14 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {

auto mdescTy = cast<xegpu::MemDescType>(mdescVal.getType());

Value basePtrLLVM = memref::ExtractAlignedPointerAsIndexOp::create(
rewriter, loc, basePtrStruct);

// Convert base pointer (ptr) to i32
Value basePtrI32 = arith::IndexCastUIOp::create(
rewriter, loc, rewriter.getI32Type(), basePtrLLVM);

Value linearOffset = mdescTy.getLinearOffsets(rewriter, loc, offsets);
linearOffset = arith::IndexCastUIOp::create(
rewriter, loc, rewriter.getI32Type(), linearOffset);
basePtrI32 = addOffsetToBaseAddr(rewriter, loc, basePtrI32, linearOffset,
elemByteSize);
Value basePtrI32 = addOffsetToBaseAddr(rewriter, loc, baseAddr32,
linearOffset, elemByteSize);

// convert base pointer (i32) to LLVM pointer type
basePtrLLVM =
Value basePtrLLVM =
LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI32);

if (op.getSubgroupBlockIoAttr()) {
@@ -1005,15 +986,14 @@ struct ConvertXeGPUToXeVMPass
auto i32Type = IntegerType::get(&getContext(), 32);
return VectorType::get(8, i32Type);
});
// Convert MemDescType into flattened MemRefType for SLM
// Convert MemDescType into i32 for SLM
typeConverter.addConversion([&](xegpu::MemDescType type) -> Type {
Type elemTy = type.getElementType();
int numElems = type.getNumElements();
return MemRefType::get(numElems, elemTy, AffineMap(), 3);
return IntegerType::get(&getContext(), 32);
});

typeConverter.addConversion([&](MemRefType type) -> Type {
// Convert MemRefType to i64 type.
if (type.getMemorySpaceAsInt() == 3)
return IntegerType::get(&getContext(), 32);
return IntegerType::get(&getContext(), 64);
});
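
Presumably the split reflects pointer widths on the target: shared local memory (address space 3) is addressed with 32-bit offsets, while flat/global memrefs use 64-bit addresses, so mem_desc values and SLM memrefs both lower to i32 and every other memref to i64.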

81 changes: 52 additions & 29 deletions mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
@@ -4,8 +4,8 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {

// e.g. for mem_desc<32x32xf16, @strides=[1, 16]>
// its memory layout tuple is (blocked shape = [1,1,32,32],strides=[1024,1024,32,1])
//CHECK-LABEL: load_store_matrix_1
gpu.func @load_store_matrix_1(%arg0: memref<4096xi8, 3>) -> f32 {
//CHECK-LABEL: load_store_matrix_plain
gpu.func @load_store_matrix_plain(%arg0: memref<4096xi8, 3>) -> f32 {
%0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x32xf32>

//CHECK: %[[TID:.*]] = gpu.thread_id x
@@ -26,20 +26,48 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
gpu.return %1: f32
}

//CHECK-LABEL: load_store_matrix_plain_2d_input
gpu.func @load_store_matrix_plain_2d_input(%arg0: memref<8192xi8, 3>) -> f32 {
%c0 = arith.constant 0 : index
%view = memref.view %arg0[%c0][]: memref<8192xi8, 3> to memref<64x32xf32, 3>

%subview = memref.subview %view[32, 0] [32, 32] [1, 1] : memref<64x32xf32, 3> to memref<32x32xf32, strided<[32, 1], offset: 1024>, 3>

%0 = xegpu.create_mem_desc %subview : memref<32x32xf32, strided<[32, 1], offset: 1024>, 3> -> !xegpu.mem_desc<32x32xf32>

//CHECK: %[[TID:.*]] = gpu.thread_id x
//CHECK: %[[C1:.*]] = arith.constant 1 : index
//CHECK: %[[MUL1:.*]] = arith.muli %[[TID]], %[[C1]] : index
//CHECK: %[[C4:.*]] = arith.constant 4 : i32
//CHECK: %[[MUL2:.*]] = arith.muli {{.*}}, %[[C4]] : i32
//CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> f32

%tid_x = gpu.thread_id x

%1 = xegpu.load_matrix %0[%c0, %tid_x]: !xegpu.mem_desc<32x32xf32>, index, index -> f32

//CHECK: llvm.store {{.*}}, {{.*}} : f32, !llvm.ptr<3>

xegpu.store_matrix %1, %0[%c0, %tid_x]: f32, !xegpu.mem_desc<32x32xf32>, index, index

gpu.return %1: f32
}


// e.g. for mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 32]>
// its memory layout tuple is ([2,4,16,16],[256,512,1,16])
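// As a worked sketch of the linearization checked below, using the constant-13
// and thread-id offsets from this test:
//   block indices    = (13 divsi 16, %tid_x divsi 16) = (0, %tid_x / 16)
//   in-block indices = (13 remsi 16, %tid_x remsi 16) = (13, %tid_x % 16)
//   linear element offset = 0*256 + (%tid_x/16)*512 + 13*1 + (%tid_x%16)*16
// which is then scaled by the 2-byte f16 element size and added to the i32 SLM
// base address.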
//CHECK-LABEL: load_store_matrix_2
gpu.func @load_store_matrix_2(%arg0: memref<4096xi8, 3>) -> f16 {
//CHECK-LABEL: load_store_matrix_blocked_strided
gpu.func @load_store_matrix_blocked_strided(%arg0: memref<4096xi8, 3>) -> f16 {
%0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>
//CHECK: %[[c0:.*]] = arith.constant 0 : index

//CHECK: %[[tid_x:.*]] = gpu.thread_id x
//CHECK: %[[c13:.*]] = arith.constant 13 : index
//CHECK: %[[c16:.*]] = arith.constant 16 : index
//CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c13]], %[[c16]] : index
//CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c13]], %[[c16]] : index
//CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
//CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index

//CHECK: %[[c0:.*]] = arith.constant 0 : index
//CHECK: %[[c256:.*]] = arith.constant 256 : index
//CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c256]] : index
//CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
@@ -68,24 +96,24 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {

// e.g. for mem_desc<32x64xf16, @block=[16, 16]>
// its memory layout tuple is ([2,4,16,16],[1024,256,16,1])
//CHECK-LABEL: load_store_matrix_3
gpu.func @load_store_matrix_3(%arg0: memref<4096xi8, 3>) -> f16 {
//CHECK: %[[c0:.*]] = arith.constant 0 : index
//CHECK: %[[view:.*]] = memref.view %arg0[%[[c0]]][] : memref<4096xi8, 3> to memref<2048xf16, 3>
//CHECK-LABEL: load_store_matrix_blocked_nostride
gpu.func @load_store_matrix_blocked_nostride(%arg0: memref<4096xi8, 3>) -> f16 {

//CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %arg0 : memref<4096xi8, 3> -> index
//CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32
%0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>

//CHECK: %[[tid_x:.*]] = gpu.thread_id x
//CHECK: %[[c19:.*]] = arith.constant 19 : index
%tid_x = gpu.thread_id x
%c19 = arith.constant 19: index

//CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[view]] : memref<2048xf16, 3> -> index
//CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32
//CHECK: %[[c16:.*]] = arith.constant 16 : index
//CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c19]], %[[c16]] : index
//CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c19]], %[[c16]] : index
//CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
//CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
//CHECK: %[[c0:.*]] = arith.constant 0 : index
//CHECK: %[[c1024:.*]] = arith.constant 1024 : index
//CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c1024]] : index
//CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
Expand All @@ -97,7 +125,6 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
//CHECK: %[[c1:.*]] = arith.constant 1 : index
//CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c1]] : index
//CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index

//CHECK: %[[loaded:.*]] = llvm.load {{.*}} : !llvm.ptr<3> -> f16
%1 = xegpu.load_matrix %0[%c19, %tid_x]: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index -> f16

Expand All @@ -110,19 +137,17 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {

// e.g. for mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 16]>
// its memory layout tuple is ([2,4,16,16],[256,512,1,16])
//CHECK-LABEL: load_store_matrix_4
gpu.func @load_store_matrix_4(%arg0: memref<4096xi8, 3>) -> vector<8xf16> {
//CHECK-LABEL: load_store_matrix_blocked_strided_return_vector
gpu.func @load_store_matrix_blocked_strided_return_vector(%arg0: memref<4096xi8, 3>) -> vector<8xf16> {
%0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>

//CHECK: %[[c0:.*]] = arith.constant 0 : index
//CHECK: %[[tid_x:.*]] = gpu.thread_id x

//CHECK: %[[c16:.*]] = arith.constant 16 : index
//CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c16]], %[[c16]] : index
//CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c16]], %[[c16]] : index
//CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
//CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index

//CHECK: %[[c0:.*]] = arith.constant 0 : index
//CHECK: %[[c256:.*]] = arith.constant 256 : index
//CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c256]] : index
//CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
@@ -150,25 +175,23 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {

// e.g. for mem_desc<32x64xf16, @block=[16, 16]>
// its memory layout tuple is ([2,4,16,16],[1024,256,16,1])
//CHECK-LABEL: load_store_matrix_5
gpu.func @load_store_matrix_5(%arg0: memref<4096xi8, 3>) -> vector<8xf16> {
//CHECK: %[[c0:.*]] = arith.constant 0 : index
//CHECK: %[[view:.*]] = memref.view %arg0[%[[c0]]][] : memref<4096xi8, 3> to memref<2048xf16, 3>

%0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>
//CHECK-LABEL: load_store_matrix_blocked_subgroupblockio
gpu.func @load_store_matrix_blocked_subgroupblockio(%arg0: memref<4096xi8, 3>) -> vector<8xf16> {

//CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %arg0 : memref<4096xi8, 3> -> index
//CHECK: %[[basePtrI32:.*]] = arith.index_castui %[[intptr]] : index to i32
%0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>

//CHECK: %[[c16:.*]] = arith.constant 16 : index
//CHECK: %[[c48:.*]] = arith.constant 48 : index

%c16 = arith.constant 16 : index
%c48 = arith.constant 48 : index

//CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[view]] : memref<2048xf16, 3> -> index
//CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32
//CHECK: %[[offset0:.*]] = arith.divsi %[[c16]], %[[c16]] : index
//CHECK: %[[offset1:.*]] = arith.remsi %[[c16]], %[[c16]] : index
//CHECK: %[[offset2:.*]] = arith.divsi %[[c48]], %[[c16]] : index
//CHECK: %[[offset3:.*]] = arith.remsi %[[c48]], %[[c16]] : index
//CHECK: %[[c0:.*]] = arith.constant 0 : index
//CHECK: %[[c1024:.*]] = arith.constant 1024 : index
//CHECK: %[[mul0:.*]] = arith.muli %[[offset0]], %[[c1024]] : index
//CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
Expand All @@ -183,7 +206,7 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
//CHECK: %[[linearOffsetI64:.*]] = arith.index_castui %[[linearOffset]] : index to i32
//CHECK: %[[c2:.*]] = arith.constant 2 : i32
//CHECK: %[[byteOffset:.*]] = arith.muli %[[linearOffsetI64]], %[[c2]] : i32
//CHECK: %[[finalPtr:.*]] = arith.addi %[[basePtrI64]], %[[byteOffset]] : i32
//CHECK: %[[finalPtr:.*]] = arith.addi %[[basePtrI32]], %[[byteOffset]] : i32
//CHECK: %[[ptr:.*]] = llvm.inttoptr %[[finalPtr]] : i32 to !llvm.ptr<3>
//CHECK: %[[loadedI16:.*]] = xevm.blockload %[[ptr]] : (!llvm.ptr<3>) -> vector<8xi16>
//CHECK: %[[loaded:.*]] = vector.bitcast %[[loadedI16]] : vector<8xi16> to vector<8xf16>
2 changes: 1 addition & 1 deletion mlir/test/Dialect/XeGPU/invalid.mlir
@@ -836,7 +836,7 @@ func.func @slice_attr_repeat_dim() {
// -----
func.func @create_mem_desc_non_slm() {
%m = memref.alloca() {alignment = 1024} : memref<2048xi8, 1>
// expected-error@+1 {{operand #0 must be statically shaped memref of 8-bit signless integer values for shared memory}}
// expected-error@+1 {{operand #0 must be statically shaped 1D memref}}
%mem_desc = xegpu.create_mem_desc %m : memref<2048xi8, 1> -> !xegpu.mem_desc<16x64xf16>
return
}
21 changes: 21 additions & 0 deletions mlir/test/Dialect/XeGPU/ops.mlir
@@ -834,6 +834,27 @@ gpu.func @create_mem_desc_with_stride() {
gpu.return
}


// CHECK-LABEL: gpu.func @create_mem_desc_from_2d_memref({{.*}}) {
gpu.func @create_mem_desc_from_2d_memref() {
//CHECK: [[alloc:%.+]] = memref.alloca() {alignment = 1024 : i64} : memref<16x64xf16, 3>
//CHECK: [[mdesc:%.+]] = xegpu.create_mem_desc [[alloc]] : memref<16x64xf16, 3> -> !xegpu.mem_desc<16x64xf16>
%m = memref.alloca() {alignment = 1024} : memref<16x64xf16, 3>
%mem_desc = xegpu.create_mem_desc %m : memref<16x64xf16, 3> -> !xegpu.mem_desc<16x64xf16>
gpu.return
}

// CHECK-LABEL: gpu.func @create_mem_desc_with_stride_from_2d_memref({{.*}}) {
gpu.func @create_mem_desc_with_stride_from_2d_memref() {
//CHECK: %[[ALLOC:.+]] = memref.alloca() {alignment = 1024 : i64} : memref<32x64xf16, 3>
//CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ALLOC]][16, 0] [16, 64] [1, 1] : memref<32x64xf16, 3> to memref<16x64xf16, strided<[64, 1], offset: 1024>, 3>
//CHECK: %{{.+}} = xegpu.create_mem_desc %[[SUBVIEW]] : memref<16x64xf16, strided<[64, 1], offset: 1024>, 3> -> !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>
%m = memref.alloca() {alignment = 1024} : memref<32x64xf16, 3>
%m_sub = memref.subview %m[16, 0][16, 64][1,1] : memref<32x64xf16, 3> to memref<16x64xf16, strided<[64, 1], offset: 1024>, 3>
%mem_desc = xegpu.create_mem_desc %m_sub : memref<16x64xf16, strided<[64, 1], offset: 1024>, 3> -> !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>
gpu.return
}

// CHECK: gpu.func @load_matrix([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>)
gpu.func @load_matrix(%arg0: !xegpu.mem_desc<16x64xf16>) {
// CHECK: xegpu.load_matrix [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16> -> vector<8x16xf16>