Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1304,11 +1304,11 @@ def XeGPU_CreateMemDescOp: XeGPU_Op<"create_mem_desc", [Pure,
as the underlying shared local memory.

Arguments:
- `source` : a 1D statically shaped memref with element type i8, representing the raw SLM buffer.
- `source` : a 1D or 2D statically shaped memref, representing the raw SLM buffer. When the source is provided as a 1D memref, its element type must be i8.
Results:
- `mem_desc` : the memory descriptor.
}];
let arguments = (ins StaticShared1DMemRefOf<[I8]>:$source);
let arguments = (ins AnyTypeOf<[StaticShared1DMemRefOf<[I8]>, ConfinedType<MemRefRankOf<[XeGPU_ScalarType], [2]>, [HasStaticShapePred, isSharedPred]>]>:$source);
let results = (outs XeGPU_MemDesc:$mem_desc);
let assemblyFormat = "$source prop-dict attr-dict `` `:` type($source) `->` qualified(type($mem_desc))";
}
Expand Down
35 changes: 11 additions & 24 deletions mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -590,16 +590,12 @@ class CreateMemDescOpPattern final
matchAndRewrite(xegpu::CreateMemDescOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

auto resTy = op.getMemDesc();
Value baseAddr = memref::ExtractAlignedPointerAsIndexOp::create(
rewriter, op.getLoc(), op.getSource());
auto baseAddr32 = arith::IndexCastUIOp::create(
rewriter, op.getLoc(), rewriter.getI32Type(), baseAddr);

// Create the result MemRefType with the same shape, element type, and
// memory space
auto newResTy = getTypeConverter()->convertType<MemRefType>(resTy);

Value zero = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0);
auto viewOp = memref::ViewOp::create(rewriter, op.getLoc(), newResTy,
op.getSource(), zero, ValueRange());
rewriter.replaceOp(op, viewOp);
rewriter.replaceOp(op, baseAddr32);
return success();
}
};
Expand All @@ -619,7 +615,7 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {

auto loc = op.getLoc();
auto ctxt = rewriter.getContext();
Value basePtrStruct = adaptor.getMemDesc();
Value baseAddr32 = adaptor.getMemDesc();
Value mdescVal = op.getMemDesc();
// Load result or Store value Type can be vector or scalar.
Value data;
Expand Down Expand Up @@ -647,21 +643,14 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {

auto mdescTy = cast<xegpu::MemDescType>(mdescVal.getType());

Value basePtrLLVM = memref::ExtractAlignedPointerAsIndexOp::create(
rewriter, loc, basePtrStruct);

// Convert base pointer (ptr) to i32
Value basePtrI32 = arith::IndexCastUIOp::create(
rewriter, loc, rewriter.getI32Type(), basePtrLLVM);

Value linearOffset = mdescTy.getLinearOffsets(rewriter, loc, offsets);
linearOffset = arith::IndexCastUIOp::create(
rewriter, loc, rewriter.getI32Type(), linearOffset);
basePtrI32 = addOffsetToBaseAddr(rewriter, loc, basePtrI32, linearOffset,
elemByteSize);
Value basePtrI32 = addOffsetToBaseAddr(rewriter, loc, baseAddr32,
linearOffset, elemByteSize);

// convert base pointer (i32) to LLVM pointer type
basePtrLLVM =
Value basePtrLLVM =
LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI32);

if (op.getSubgroupBlockIoAttr()) {
Expand Down Expand Up @@ -1005,11 +994,9 @@ struct ConvertXeGPUToXeVMPass
auto i32Type = IntegerType::get(&getContext(), 32);
return VectorType::get(8, i32Type);
});
// Convert MemDescType into flattened MemRefType for SLM
// Convert MemDescType into i32 for SLM
typeConverter.addConversion([&](xegpu::MemDescType type) -> Type {
Type elemTy = type.getElementType();
int numElems = type.getNumElements();
return MemRefType::get(numElems, elemTy, AffineMap(), 3);
return IntegerType::get(&getContext(), 32);
});

typeConverter.addConversion([&](MemRefType type) -> Type {
Expand Down
80 changes: 52 additions & 28 deletions mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {

// e.g. for mem_desc<32x32xf16, @strides=[1, 16]>
// its memory layout tuple is (blocked shape = [1,1,32,32],strides=[1024,1024,32,1])
//CHECK-LABEL: load_store_matrix_1
gpu.func @load_store_matrix_1(%arg0: memref<4096xi8, 3>) -> f32 {
//CHECK-LABEL: load_store_matrix_plain
gpu.func @load_store_matrix_plain(%arg0: memref<4096xi8, 3>) -> f32 {
%0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x32xf32>

//CHECK: %[[TID:.*]] = gpu.thread_id x
Expand All @@ -26,20 +26,48 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
gpu.return %1: f32
}

//CHECK-LABEL: load_store_matrix_plain_2d_input
gpu.func @load_store_matrix_plain_2d_input(%arg0: memref<8192xi8, 3>) -> f32 {
%c0 = arith.constant 0 : index
%view = memref.view %arg0[%c0][]: memref<8192xi8, 3> to memref<64x32xf32, 3>

%subview = memref.subview %view[32, 0] [32, 32] [1, 1] : memref<64x32xf32, 3> to memref<32x32xf32, strided<[32, 1], offset: 1024>, 3>

%0 = xegpu.create_mem_desc %subview : memref<32x32xf32, strided<[32, 1], offset: 1024>, 3> -> !xegpu.mem_desc<32x32xf32>

//CHECK: %[[TID:.*]] = gpu.thread_id x
//CHECK: %[[C1:.*]] = arith.constant 1 : index
//CHECK: %[[MUL1:.*]] = arith.muli %[[TID]], %[[C1]] : index
//CHECK: %[[C4:.*]] = arith.constant 4 : i32
//CHECK: %[[MUL2:.*]] = arith.muli {{.*}}, %[[C4]] : i32
//CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> f32

%tid_x = gpu.thread_id x

%1 = xegpu.load_matrix %0[%c0, %tid_x]: !xegpu.mem_desc<32x32xf32>, index, index -> f32

//CHECK: llvm.store {{.*}}, {{.*}} : f32, !llvm.ptr<3>

xegpu.store_matrix %1, %0[%c0, %tid_x]: f32, !xegpu.mem_desc<32x32xf32>, index, index

gpu.return %1: f32
}


// e.g. for mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 32]>
// its memory layout tuple is ([2,4,16,16],[256,512,1,16])
//CHECK-LABEL: load_store_matrix_2
gpu.func @load_store_matrix_2(%arg0: memref<4096xi8, 3>) -> f16 {
//CHECK-LABEL: load_store_matrix_blocked_strided
gpu.func @load_store_matrix_blocked_strided(%arg0: memref<4096xi8, 3>) -> f16 {
%0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>
//CHECK: %[[c0:.*]] = arith.constant 0 : index

//CHECK: %[[tid_x:.*]] = gpu.thread_id x
//CHECK: %[[c13:.*]] = arith.constant 13 : index
//CHECK: %[[c16:.*]] = arith.constant 16 : index
//CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c13]], %[[c16]] : index
//CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c13]], %[[c16]] : index
//CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
//CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index

//CHECK: %[[c0:.*]] = arith.constant 0 : index
//CHECK: %[[c256:.*]] = arith.constant 256 : index
//CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c256]] : index
//CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
Expand Down Expand Up @@ -68,24 +96,24 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {

// e.g. for mem_desc<32x64xf16, @block=[16, 16]>
// its memory layout tuple is ([2,4,16,16],[1024,256,16,1])
//CHECK-LABEL: load_store_matrix_3
gpu.func @load_store_matrix_3(%arg0: memref<4096xi8, 3>) -> f16 {
//CHECK: %[[c0:.*]] = arith.constant 0 : index
//CHECK: %[[view:.*]] = memref.view %arg0[%[[c0]]][] : memref<4096xi8, 3> to memref<2048xf16, 3>
//CHECK-LABEL: load_store_matrix_blocked_nostride
gpu.func @load_store_matrix_blocked_nostride(%arg0: memref<4096xi8, 3>) -> f16 {

//CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %arg0 : memref<4096xi8, 3> -> index
//CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32
%0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>

//CHECK: %[[tid_x:.*]] = gpu.thread_id x
//CHECK: %[[c19:.*]] = arith.constant 19 : index
%tid_x = gpu.thread_id x
%c19 = arith.constant 19: index

//CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[view]] : memref<2048xf16, 3> -> index
//CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32
//CHECK: %[[c16:.*]] = arith.constant 16 : index
//CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c19]], %[[c16]] : index
//CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c19]], %[[c16]] : index
//CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
//CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
//CHECK: %[[c0:.*]] = arith.constant 0 : index
//CHECK: %[[c1024:.*]] = arith.constant 1024 : index
//CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c1024]] : index
//CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
Expand All @@ -97,7 +125,6 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
//CHECK: %[[c1:.*]] = arith.constant 1 : index
//CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c1]] : index
//CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index

//CHECK: %[[loaded:.*]] = llvm.load {{.*}} : !llvm.ptr<3> -> f16
%1 = xegpu.load_matrix %0[%c19, %tid_x]: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index -> f16

Expand All @@ -110,19 +137,17 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {

// e.g. for mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 16]>
// its memory layout tuple is ([2,4,16,16],[256,512,1,16])
//CHECK-LABEL: load_store_matrix_4
gpu.func @load_store_matrix_4(%arg0: memref<4096xi8, 3>) -> vector<8xf16> {
//CHECK-LABEL: load_store_matrix_blocked_strided_return_vector
gpu.func @load_store_matrix_blocked_strided_return_vector(%arg0: memref<4096xi8, 3>) -> vector<8xf16> {
%0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>

//CHECK: %[[c0:.*]] = arith.constant 0 : index
//CHECK: %[[tid_x:.*]] = gpu.thread_id x

//CHECK: %[[c16:.*]] = arith.constant 16 : index
//CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c16]], %[[c16]] : index
//CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c16]], %[[c16]] : index
//CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
//CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index

//CHECK: %[[c0:.*]] = arith.constant 0 : index
//CHECK: %[[c256:.*]] = arith.constant 256 : index
//CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c256]] : index
//CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
Expand Down Expand Up @@ -150,25 +175,24 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {

// e.g. for mem_desc<32x64xf16, @block=[16, 16]>
// its memory layout tuple is ([2,4,16,16],[1024,256,16,1])
//CHECK-LABEL: load_store_matrix_5
gpu.func @load_store_matrix_5(%arg0: memref<4096xi8, 3>) -> vector<8xf16> {
//CHECK: %[[c0:.*]] = arith.constant 0 : index
//CHECK: %[[view:.*]] = memref.view %arg0[%[[c0]]][] : memref<4096xi8, 3> to memref<2048xf16, 3>

%0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>

//CHECK-LABEL: load_store_matrix_blocked_subgroupblockio
gpu.func @load_store_matrix_blocked_subgroupblockio(%arg0: memref<4096xi8, 3>) -> vector<8xf16> {

//CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %arg0 : memref<4096xi8, 3> -> index
//CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: wrong type in the name, maybe a generic basePtr would be better?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed.

%0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>


//CHECK: %[[c16:.*]] = arith.constant 16 : index
//CHECK: %[[c48:.*]] = arith.constant 48 : index

%c16 = arith.constant 16 : index
%c48 = arith.constant 48 : index

//CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[view]] : memref<2048xf16, 3> -> index
//CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32
//CHECK: %[[offset0:.*]] = arith.divsi %[[c16]], %[[c16]] : index
//CHECK: %[[offset1:.*]] = arith.remsi %[[c16]], %[[c16]] : index
//CHECK: %[[offset2:.*]] = arith.divsi %[[c48]], %[[c16]] : index
//CHECK: %[[offset3:.*]] = arith.remsi %[[c48]], %[[c16]] : index
//CHECK: %[[c0:.*]] = arith.constant 0 : index
//CHECK: %[[c1024:.*]] = arith.constant 1024 : index
//CHECK: %[[mul0:.*]] = arith.muli %[[offset0]], %[[c1024]] : index
//CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
Expand Down
21 changes: 21 additions & 0 deletions mlir/test/Dialect/XeGPU/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -834,6 +834,27 @@ gpu.func @create_mem_desc_with_stride() {
gpu.return
}


// CHECK-LABEL: gpu.func @create_mem_desc_from_2d_memref({{.*}}) {
gpu.func @create_mem_desc_from_2d_memref() {
//CHECK: [[alloc:%.+]] = memref.alloca() {alignment = 1024 : i64} : memref<16x64xf16, 3>
//CHECK: [[mdesc:%.+]] = xegpu.create_mem_desc [[alloc]] : memref<16x64xf16, 3> -> !xegpu.mem_desc<16x64xf16>
%m = memref.alloca() {alignment = 1024} : memref<16x64xf16, 3>
%mem_desc = xegpu.create_mem_desc %m : memref<16x64xf16, 3> -> !xegpu.mem_desc<16x64xf16>
gpu.return
}

// CHECK-LABEL: gpu.func @create_mem_desc_with_stride_from_2d_memref({{.*}}) {
gpu.func @create_mem_desc_with_stride_from_2d_memref() {
//CHECK: %[[ALLOC:.+]] = memref.alloca() {alignment = 1024 : i64} : memref<32x64xf16, 3>
//CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ALLOC]][16, 0] [16, 64] [1, 1] : memref<32x64xf16, 3> to memref<16x64xf16, strided<[64, 1], offset: 1024>, 3>
//CHECK: %{{.+}} = xegpu.create_mem_desc %[[SUBVIEW]] : memref<16x64xf16, strided<[64, 1], offset: 1024>, 3> -> !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>
%m = memref.alloca() {alignment = 1024} : memref<32x64xf16, 3>
%m_sub = memref.subview %m[16, 0][16, 64][1,1] : memref<32x64xf16, 3> to memref<16x64xf16, strided<[64, 1], offset: 1024>, 3>
%mem_desc = xegpu.create_mem_desc %m_sub : memref<16x64xf16, strided<[64, 1], offset: 1024>, 3> -> !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>
gpu.return
}

// CHECK: gpu.func @load_matrix([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>)
gpu.func @load_matrix(%arg0: !xegpu.mem_desc<16x64xf16>) {
// CHECK: xegpu.load_matrix [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16> -> vector<8x16xf16>
Expand Down
Loading