[MLIR][XeGPU] Support subview memref: handling the base address during xegpu to xevm type conversion #170541
base: main
Conversation
@llvm/pr-subscribers-mlir-gpu @llvm/pr-subscribers-mlir

Author: Jianhui Li (Jianhui-Li)

Changes

During the XeGPU-to-XeVM type conversion, a memref is lowered to its base address. This PR extends the conversion to correctly handle memrefs that carry an offset, such as those produced by memref.subview.

Patch is 25.66 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/170541.diff

6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
index 58092c3bb9ed2..b5978dc8d7b74 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -175,6 +175,9 @@ template <typename T>
int getLargestDivisor(T dim, ArrayRef<T> candidates,
ArrayRef<T> candidateMultiples = {});
+/// Checks if the given MemRefType refers to shared memory.
+bool isSharedMemRef(const MemRefType &memrefTy);
+
} // namespace xegpu
} // namespace mlir
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 7f1ec17ce0ae8..a1c2745864dc6 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -991,15 +991,14 @@ struct ConvertXeGPUToXeVMPass
});
typeConverter.addConversion([&](MemRefType type) -> Type {
- if (type.getMemorySpaceAsInt() == 3)
- return IntegerType::get(&getContext(), 32);
- return IntegerType::get(&getContext(), 64);
+ return IntegerType::get(&getContext(),
+ (xegpu::isSharedMemRef(type) ? 32 : 64));
});
// LLVM type converter puts unrealized casts for the following cases:
// add materialization casts to handle them.
- // Materialization to convert memref to i64
+ // Materialization to convert memref to i64 or i32 depending on global/SLM
auto memrefMaterializationCast = [](OpBuilder &builder, Type type,
ValueRange inputs,
Location loc) -> Value {
@@ -1007,11 +1006,55 @@ struct ConvertXeGPUToXeVMPass
return {};
auto input = inputs.front();
if (auto memrefTy = dyn_cast<MemRefType>(input.getType())) {
+ unsigned rank = memrefTy.getRank();
+ Type indexType = builder.getIndexType();
- Value addr =
- memref::ExtractAlignedPointerAsIndexOp::create(builder, loc, input);
- return arith::IndexCastUIOp::create(builder, loc, type, addr)
- .getResult();
+ int64_t intOffsets;
+ SmallVector<int64_t> intStrides;
+ Value addr;
+ Value offset;
+ if (failed(memrefTy.getStridesAndOffset(intStrides, intOffsets))) {
+
+ // Result types: [base_memref, offset, stride0, stride1, ...,
+ // strideN-1, size0, size1, ..., sizeN-1]
+ SmallVector<Type> resultTypes{
+ MemRefType::get({}, memrefTy.getElementType(),
+ MemRefLayoutAttrInterface(),
+ memrefTy.getMemorySpace()),
+ indexType};
+ // strides + sizes
+ resultTypes.append(2 * rank, indexType);
+
+ auto meta = memref::ExtractStridedMetadataOp::create(
+ builder, loc, resultTypes, input);
+
+ addr = memref::ExtractAlignedPointerAsIndexOp::create(
+ builder, loc, meta.getBaseBuffer());
+ offset = meta.getOffset();
+
+ } else {
+ addr = memref::ExtractAlignedPointerAsIndexOp::create(builder, loc,
+ input);
+ offset = arith::ConstantOp::create(builder, loc,
+ builder.getIndexAttr(intOffsets));
+ }
+
+ auto addr_casted =
+ arith::IndexCastUIOp::create(builder, loc, type, addr);
+ auto offset_casted =
+ arith::IndexCastUIOp::create(builder, loc, type, offset);
+
+ // Compute the final address: base address + byte offset
+ auto byte_size = arith::ConstantOp::create(
+ builder, loc, type,
+ builder.getIntegerAttr(type,
+ memrefTy.getElementTypeBitWidth() / 8));
+ auto byte_offset =
+ arith::MulIOp::create(builder, loc, offset_casted, byte_size);
+ auto addr_with_offset =
+ arith::AddIOp::create(builder, loc, addr_casted, byte_offset);
+
+ return addr_with_offset.getResult();
}
return {};
};
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 91432b1c11304..eecbb7b907e9f 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -580,3 +580,15 @@ template int xegpu::getLargestDivisor<int>(int dim, ArrayRef<int> candidates,
template int
xegpu::getLargestDivisor<unsigned>(unsigned dim, ArrayRef<unsigned> candidates,
ArrayRef<unsigned> candidateMultiples);
+
+/// Checks if the given MemRefType refers to shared memory.
+bool xegpu::isSharedMemRef(const MemRefType &memrefTy) {
+ Attribute attr = memrefTy.getMemorySpace();
+ if (!attr)
+ return false;
+ if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr))
+ return intAttr.getInt() == static_cast<int>(xevm::AddrSpace::SHARED);
+ if (auto xevmSpace = llvm::dyn_cast<xevm::AddrSpaceAttr>(attr))
+ return xevmSpace.getValue() == xevm::AddrSpace::SHARED;
+ return gpu::GPUDialect::isWorkgroupMemoryAddressSpace(attr);
+}
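
As a hedged illustration (the examples below are the editor's assumptions, not part of the patch), the helper treats the different spellings of shared memory uniformly:

// Assumed examples, not from the patch: memory spaces isSharedMemRef
// should classify as shared, which the pass then lowers to i32 addresses.
//   memref<16x16xf32, 3>                              // IntegerAttr equal to xevm::AddrSpace::SHARED
//   memref<16x16xf32, #xevm.addr_space<shared>>       // xevm::AddrSpaceAttr (attribute syntax assumed)
//   memref<16x16xf32, #gpu.address_space<workgroup>>  // GPU workgroup memory
// A memref with no memory-space attribute is treated as global (i64).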
diff --git a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
index 8b87b791c9fd3..242101955b900 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
@@ -8,7 +8,8 @@ gpu.module @create_nd_tdesc {
gpu.func @create_nd_tdesc(%src: memref<16x32xf32, 1>, %ptr: ui64, %shape1: index, %shape2: index,
%stride1: index, %stride2: index, %offset1: index, %offset2: index, %dyn: memref<?x?xf16>) kernel {
// CHECK: %[[INTPTR_5:.*]] = memref.extract_aligned_pointer_as_index %[[DYN]] : memref<?x?xf16> -> index
- // CHECK: %[[VAR23:.*]] = arith.index_castui %[[INTPTR_5]] : index to i64
+ // CHECK: %[[BASE_ADDR3:.*]] = arith.index_castui %[[INTPTR_5]] : index to i64
+
// CHECK: %[[VAR0:.*]] = index.castu %[[ARG1]] : ui64 to index
// CHECK: %[[BASE_ADDR:.*]] = arith.index_castui %[[VAR0]] : index to i64
// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32>
@@ -39,7 +40,7 @@ gpu.module @create_nd_tdesc {
// CHECK: %[[C16_I64:.*]] = arith.constant 16 : i64
// CHECK: %[[SHAPE_H2:.*]] = arith.trunci %[[C16_I64]] : i64 to i32
// CHECK: %[[VAR14:.*]] = vector.bitcast %[[CST_1]] : vector<8xi32> to vector<4xi64>
- // CHECK: %[[VAR15:.*]] = vector.insert %[[BASE_ADDR2]], %[[VAR14]] [0] : i64 into vector<4xi64>
+ // CHECK: %[[VAR15:.*]] = vector.insert %[[BASE_ADDR2_OFFSET:.*]], %[[VAR14]] [0] : i64 into vector<4xi64>
// CHECK: %[[VAR16:.*]] = vector.bitcast %[[VAR15]] : vector<4xi64> to vector<8xi32>
// CHECK: %[[VAR17:.*]] = vector.insert %[[SHAPE_W2]], %[[VAR16]] [2] : i32 into vector<8xi32>
// CHECK: %[[VAR18:.*]] = vector.insert %[[SHAPE_H2]], %[[VAR17]] [3] : i32 into vector<8xi32>
@@ -53,13 +54,14 @@ gpu.module @create_nd_tdesc {
%size_x = arith.constant 64 : index
// CHECK: %[[C16:.*]] = arith.constant 16 : index
%BLOCK_DMODEL = arith.constant 16 : index
+
// CHECK: %[[CST_4:.*]] = arith.constant dense<0> : vector<8xi32>
// CHECK: %[[C0_I32_6:.*]] = arith.constant 0 : i32
// CHECK: %[[C0_I32_7:.*]] = arith.constant 0 : i32
// CHECK: %[[VAR21:.*]] = arith.index_cast %[[C16]] : index to i32
// CHECK: %[[VAR22:.*]] = arith.index_cast %[[C64]] : index to i32
// CHECK: %[[VAR24:.*]] = vector.bitcast %[[CST_4]] : vector<8xi32> to vector<4xi64>
- // CHECK: %[[VAR25:.*]] = vector.insert %[[VAR23]], %[[VAR24]] [0] : i64 into vector<4xi64>
+ // CHECK: %[[VAR25:.*]] = vector.insert %[[BASE_ADDR3_OFFSET:.*]], %[[VAR24]] [0] : i64 into vector<4xi64>
// CHECK: %[[VAR26:.*]] = vector.bitcast %[[VAR25]] : vector<4xi64> to vector<8xi32>
// CHECK: %[[VAR27:.*]] = vector.insert %[[VAR21]], %[[VAR26]] [2] : i32 into vector<8xi32>
// CHECK: %[[VAR28:.*]] = vector.insert %[[VAR22]], %[[VAR27]] [3] : i32 into vector<8xi32>
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
index ac95a1a5707ea..179fd397d7074 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
@@ -33,22 +33,28 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
%subview = memref.subview %view[32, 0] [32, 32] [1, 1] : memref<64x32xf32, 3> to memref<32x32xf32, strided<[32, 1], offset: 1024>, 3>
+ //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[base_buffer:.*]] : memref<32x32xf32, strided<[32, 1], offset: 1024>, 3> -> index
+ //CHECK-DAG: %[[ptr_i32:.*]] = arith.index_castui %[[intptr]] : index to i32
+ //CHECK-DAG: %[[offset_i32:.*]] = arith.index_castui %[[offset:.*]] : index to i32
+ //CHECK-DAG: %[[c4_i32:.*]] = arith.constant 4 : i32
+ //CHECK-DAG: %[[mul:.*]] = arith.muli %[[offset_i32]], %[[c4_i32]] : i32
+ //CHECK-DAG: %[[add:.*]] = arith.addi %[[ptr_i32]], %[[mul]] : i32
+
%0 = xegpu.create_mem_desc %subview : memref<32x32xf32, strided<[32, 1], offset: 1024>, 3> -> !xegpu.mem_desc<32x32xf32>
- //CHECK: %[[TID:.*]] = gpu.thread_id x
- //CHECK: %[[C1:.*]] = arith.constant 1 : index
- //CHECK: %[[MUL1:.*]] = arith.muli %[[TID]], %[[C1]] : index
- //CHECK: %[[C4:.*]] = arith.constant 4 : i32
- //CHECK: %[[MUL2:.*]] = arith.muli {{.*}}, %[[C4]] : i32
- //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> f32
+ //CHECK-DAG: %[[TID:.*]] = gpu.thread_id x
+ //CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+ //CHECK-DAG: %[[MUL1:.*]] = arith.muli %[[TID]], %[[C1]] : index
+ //CHECK-DAG: %[[MUL2:.*]] = arith.muli {{.*}}, {{.*}} : i32
+ //CHECK-DAG: llvm.load {{.*}} : !llvm.ptr<3> -> f32
%tid_x = gpu.thread_id x
-
+
%1 = xegpu.load_matrix %0[%c0, %tid_x]: !xegpu.mem_desc<32x32xf32>, index, index -> f32
//CHECK: llvm.store {{.*}}, {{.*}} : f32, !llvm.ptr<3>
- xegpu.store_matrix %1, %0[%c0, %tid_x]: f32, !xegpu.mem_desc<32x32xf32>, index, index
+ xegpu.store_matrix %1, %0[%c0, %tid_x]: f32, !xegpu.mem_desc<32x32xf32>, index, index
gpu.return %1: f32
}
@@ -60,25 +66,25 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
gpu.func @load_store_matrix_blocked_strided(%arg0: memref<4096xi8, 3>) -> f16 {
%0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>
- //CHECK: %[[tid_x:.*]] = gpu.thread_id x
- //CHECK: %[[c13:.*]] = arith.constant 13 : index
- //CHECK: %[[c16:.*]] = arith.constant 16 : index
- //CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c13]], %[[c16]] : index
- //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c13]], %[[c16]] : index
- //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
- //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
- //CHECK: %[[c0:.*]] = arith.constant 0 : index
- //CHECK: %[[c256:.*]] = arith.constant 256 : index
- //CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c256]] : index
- //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
- //CHECK: %[[c512:.*]] = arith.constant 512 : index
- //CHECK: %[[mul1:.*]] = arith.muli %[[offsety_0]], %[[c512]] : index
- //CHECK: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index
- //CHECK: %[[c1:.*]] = arith.constant 1 : index
- //CHECK: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c1]] : index
- //CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index
- //CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c16]] : index
- //CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
+ //CHECK-DAG: %[[tid_x:.*]] = gpu.thread_id x
+ //CHECK-DAG: %[[c13:.*]] = arith.constant 13 : index
+ //CHECK-DAG: %[[c16:.*]] = arith.constant 16 : index
+ //CHECK-DAG: %[[offsetx_0:.*]] = arith.divsi %[[c13]], %[[c16]] : index
+ //CHECK-DAG: %[[offsetx_1:.*]] = arith.remsi %[[c13]], %[[c16]] : index
+ //CHECK-DAG: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
+ //CHECK-DAG: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
+ //CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+ //CHECK-DAG: %[[c256:.*]] = arith.constant 256 : index
+ //CHECK-DAG: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c256]] : index
+ //CHECK-DAG: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
+ //CHECK-DAG: %[[c512:.*]] = arith.constant 512 : index
+ //CHECK-DAG: %[[mul1:.*]] = arith.muli %[[offsety_0]], %[[c512]] : index
+ //CHECK-DAG: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index
+ //CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
+ //CHECK-DAG: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c1]] : index
+ //CHECK-DAG: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index
+ //CHECK-DAG: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c16]] : index
+ //CHECK-DAG: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
//CHECK: %[[loaded:.*]] = llvm.load {{.*}}: !llvm.ptr<3> -> f16
@@ -99,33 +105,31 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
//CHECK-LABEL: load_store_matrix_blocked_nostride
gpu.func @load_store_matrix_blocked_nostride(%arg0: memref<4096xi8, 3>) -> f16 {
- //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %arg0 : memref<4096xi8, 3> -> index
- //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32
%0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>
- //CHECK: %[[tid_x:.*]] = gpu.thread_id x
- //CHECK: %[[c19:.*]] = arith.constant 19 : index
+ //CHECK-DAG: %[[tid_x:.*]] = gpu.thread_id x
+ //CHECK-DAG: %[[c19:.*]] = arith.constant 19 : index
%tid_x = gpu.thread_id x
%c19 = arith.constant 19: index
- //CHECK: %[[c16:.*]] = arith.constant 16 : index
- //CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c19]], %[[c16]] : index
- //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c19]], %[[c16]] : index
- //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
- //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
- //CHECK: %[[c0:.*]] = arith.constant 0 : index
- //CHECK: %[[c1024:.*]] = arith.constant 1024 : index
- //CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c1024]] : index
- //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
- //CHECK: %[[c256:.*]] = arith.constant 256 : index
- //CHECK: %[[mul1:.*]] = arith.muli %[[offsety_0]], %[[c256]] : index
- //CHECK: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index
- //CHECK: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c16]] : index
- //CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index
- //CHECK: %[[c1:.*]] = arith.constant 1 : index
- //CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c1]] : index
- //CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
- //CHECK: %[[loaded:.*]] = llvm.load {{.*}} : !llvm.ptr<3> -> f16
+ //CHECK-DAG: %[[c16:.*]] = arith.constant 16 : index
+ //CHECK-DAG: %[[offsetx_0:.*]] = arith.divsi %[[c19]], %[[c16]] : index
+ //CHECK-DAG: %[[offsetx_1:.*]] = arith.remsi %[[c19]], %[[c16]] : index
+ //CHECK-DAG: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
+ //CHECK-DAG: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
+ //CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+ //CHECK-DAG: %[[c1024:.*]] = arith.constant 1024 : index
+ //CHECK-DAG: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c1024]] : index
+ //CHECK-DAG: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
+ //CHECK-DAG: %[[c256:.*]] = arith.constant 256 : index
+ //CHECK-DAG: %[[mul1:.*]] = arith.muli %[[offsety_0]], %[[c256]] : index
+ //CHECK-DAG: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index
+ //CHECK-DAG: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c16]] : index
+ //CHECK-DAG: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index
+ //CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
+ //CHECK-DAG: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c1]] : index
+ //CHECK-DAG: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
+ //CHECK-DAG: %[[loaded:.*]] = llvm.load {{.*}} : !llvm.ptr<3> -> f16
%1 = xegpu.load_matrix %0[%c19, %tid_x]: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index -> f16
//CHECK: llvm.store %[[loaded]], {{.*}} : f16, !llvm.ptr<3>
@@ -141,24 +145,24 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
gpu.func @load_store_matrix_blocked_strided_return_vector(%arg0: memref<4096xi8, 3>) -> vector<8xf16> {
%0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>
- //CHECK: %[[tid_x:.*]] = gpu.thread_id x
- //CHECK: %[[c16:.*]] = arith.constant 16 : index
- //CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c16]], %[[c16]] : index
- //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c16]], %[[c16]] : index
- //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
- //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
- //CHECK: %[[c0:.*]] = arith.constant 0 : index
- //CHECK: %[[c256:.*]] = arith.constant 256 : index
- //CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c256]] : index
- //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
- //CHECK: %[[c512:.*]] = arith.constant 512 : index
- //CHECK: %[[mul1:.*]] = arith.muli %[[offsety_0]], %[[c512]] : index
- //CHECK: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index
- //CHECK: %[[c1:.*]] = arith.constant 1 : index
- //CHECK: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c1]] : index
- //CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index
- //CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c16]] : index
- //CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
+ //CHECK-DAG: %[[tid_x:.*]] = gpu.thread_id x
+ //CHECK-DAG: %[[c16:.*]] = arith.constant 16 : index
+ //CHECK-DAG: %[[offsetx_0:.*]] = arith.divsi %[[c16]], %[[c16]] : index
+ //CHECK-DAG: %[[offsetx_1:.*]] = arith.remsi %[[c16]], %[[c16]] : index
+ //CHECK-DAG: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
+ //CHECK-DAG: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
+ //CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+ //CHECK-DAG: %[[c256:.*]] = arith.constant 256 : index
+ //CHECK-DAG: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c256]] : index
+ //CHECK-DAG: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
+ //CHECK-DAG: %[[c512:.*]] = arith.constant 512 : index
+ //CHECK-DAG: %[[mul1:.*]] = arith.muli %[[offsety_0]], %[[c512]] : index
+ //CHECK-DAG: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index
+ //CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
+ //CHECK-DAG: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c1]] : index
+ //CHECK-DAG: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index
+ //CHECK-DAG: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c16]] : index
+ //CHECK-DAG: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
//CHECK: %[[loaded:.*]] = llvm.load {{.*}}: !llvm.ptr<3> -> vector<8xf16>
@@ -178,8 +182,9 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
//CHECK-LABEL: load_store_matrix_blocked_subgroupblockio
gpu.func @load_store_matrix_blocked_subgroupblockio(%arg0: memref<4096xi8, 3>) -> vector<8xf16> {
- //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %arg0 : memref<4096xi8, 3> -> index
- //CHECK: %[[basePtrI32:.*]] = arith.index_castui %[[intptr]] : index to i32
+ //CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+ //CHECK-DAG: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[base_buffer:.*]] : memref<4096xi8, 3> -> index
+ //CHECK-DAG: %[[basePtrI32:.*]] = arith.index_castui %[[intptr]] : index to i32
%0 = xegpu.create_mem_...
[truncated]
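
Condensing the pattern checked by the updated loadstore_matrix test above: for a shared-memory subview with a static element offset, the new materialization emits address arithmetic along these lines (a sketch only; SSA names are illustrative, not the converter's actual output):

%intptr  = memref.extract_aligned_pointer_as_index %subview
    : memref<32x32xf32, strided<[32, 1], offset: 1024>, 3> -> index
%ptr_i32 = arith.index_castui %intptr : index to i32  // i32 because address space 3 is SLM
%c1024   = arith.constant 1024 : index                // element offset from the strided layout
%off_i32 = arith.index_castui %c1024 : index to i32
%c4_i32  = arith.constant 4 : i32                     // f32 element width in bytes
%bytes   = arith.muli %off_i32, %c4_i32 : i32         // byte offset
%addr    = arith.addi %ptr_i32, %bytes : i32          // final base address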
silee2 left a comment:
LGTM.
charithaintc left a comment:
Waiting for some clarifications before digging deep.
    ArrayRef<T> candidateMultiples = {});

/// Checks if the given MemRefType refers to shared memory.
bool isSharedMemRef(const MemRefType &memrefTy);
nit: This is only used in a single file (XeGPUToXeVM). For now you can keep it as a helper inside that file; if more uses arise, we can move it to a common place.
                                             builder.getIndexAttr(intOffsets));
        }

        auto addr_casted =
LLVM does not use snake_case variable naming. Rename to addrCasted.
        SmallVector<int64_t> intStrides;
        Value addr;
        Value offset;
        if (failed(memrefTy.getStridesAndOffset(intStrides, intOffsets))) {
Why this check? Why not use ExtractStridedMetadataOp for all cases? The case above could return ShapedType::kDynamic for dynamic values.
          addr = memref::ExtractAlignedPointerAsIndexOp::create(builder, loc,
                                                                input);
          offset = arith::ConstantOp::create(builder, loc,
                                             builder.getIndexAttr(intOffsets));
This could be kDynamic, which is a special sentinel value, so this code is not correct.
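
To illustrate the concern with a hypothetical case (not one of this PR's tests): a subview with a dynamic offset still has a strided layout, so getStridesAndOffset succeeds, but the offset it reports is ShapedType::kDynamic rather than a usable constant:

// Hypothetical input, not from this PR's tests. The result layout below is
// strided<[32, 1], offset: ?>, so getStridesAndOffset() succeeds, yet the
// reported offset is ShapedType::kDynamic; materializing it with
// builder.getIndexAttr(intOffsets) would bake the sentinel into the IR.
func.func @dynamic_offset(%src: memref<64x32xf32>, %off: index) {
  %sv = memref.subview %src[%off, 0] [32, 32] [1, 1]
      : memref<64x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>>
  return
}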