From 43d9ddb971ed0d80a44ef95d6d7751c5996337bb Mon Sep 17 00:00:00 2001
From: Jianhui Li
Date: Tue, 2 Dec 2025 00:13:27 +0000
Subject: [PATCH 1/2] Support memref subview in XeGPU to XeVM type conversion

---
 .../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp    | 43 ++++++++++++++++---
 .../XeGPUToXeVM/loadstore_matrix.mlir         | 22 ++++++----
 2 files changed, 52 insertions(+), 13 deletions(-)

diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 7f1ec17ce0ae8..bafd1dc348e5b 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -999,7 +999,7 @@ struct ConvertXeGPUToXeVMPass
 
     // LLVM type converter puts unrealized casts for the following cases:
     // add materialization casts to handle them.
-    // Materialization to convert memref to i64
+    // Materialization to convert memref to i64 or i32 depending on global/SLM
     auto memrefMaterializationCast = [](OpBuilder &builder, Type type,
                                         ValueRange inputs,
                                         Location loc) -> Value {
@@ -1007,11 +1007,44 @@ struct ConvertXeGPUToXeVMPass
         return {};
       auto input = inputs.front();
       if (auto memrefTy = dyn_cast<MemRefType>(input.getType())) {
+        unsigned rank = memrefTy.getRank();
+        Type indexType = builder.getIndexType();
 
-        Value addr =
-            memref::ExtractAlignedPointerAsIndexOp::create(builder, loc, input);
-        return arith::IndexCastUIOp::create(builder, loc, type, addr)
-            .getResult();
+        SmallVector<Type> resultTypes;
+        // Result types: [base_memref, offset, stride0, stride1, ..., strideN-1,
+        // size0, size1, ..., sizeN-1]
+        resultTypes.push_back(MemRefType::get(
+            {}, memrefTy.getElementType(), MemRefLayoutAttrInterface(),
+            memrefTy.getMemorySpace())); // base memref (unranked)
+        resultTypes.push_back(indexType); // offset
+        for (unsigned i = 0; i < rank; ++i)
+          resultTypes.push_back(indexType); // strides
+        for (unsigned i = 0; i < rank; ++i)
+          resultTypes.push_back(indexType); // sizes
+
+        auto meta = memref::ExtractStridedMetadataOp::create(
+            builder, loc, resultTypes, input);
+
+        auto addr = memref::ExtractAlignedPointerAsIndexOp::create(
+            builder, loc, meta.getBaseBuffer());
+        auto offset = meta.getOffset();
+
+        auto addr_casted =
+            arith::IndexCastUIOp::create(builder, loc, type, addr);
+        auto offset_casted =
+            arith::IndexCastUIOp::create(builder, loc, type, offset);
+
+        // Compute the final address: base address + byte offset
+        auto byte_size = arith::ConstantOp::create(
+            builder, loc, type,
+            builder.getIntegerAttr(type,
+                                   memrefTy.getElementTypeBitWidth() / 8));
+        auto byte_offset =
+            arith::MulIOp::create(builder, loc, offset_casted, byte_size);
+        auto addr_with_offset =
+            arith::AddIOp::create(builder, loc, addr_casted, byte_offset);
+
+        return addr_with_offset.getResult();
       }
       return {};
     };
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
index ac95a1a5707ea..aba73b80f0439 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
@@ -33,22 +33,29 @@ gpu.module @test_kernel [#xevm.target] {
 
     %subview = memref.subview %view[32, 0] [32, 32] [1, 1] : memref<64x32xf32, 3> to memref<32x32xf32, strided<[32, 1], offset: 1024>, 3>
 
+    //CHECK: %[[base_buffer:.*]], %[[offset:.*]], %[[sizes:.*]]:2, %[[strides:.*]]:2 = memref.extract_strided_metadata %{{.*}} : memref<32x32xf32, strided<[32, 1], offset: 1024>, 3> -> memref<f32, 3>, index, index, index, index, index
+    //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[base_buffer]] : memref<f32, 3> -> index
+    //CHECK: %[[ptr_i32:.*]] = arith.index_castui %[[intptr]] : index to i32
+    //CHECK: %[[offset_i32:.*]] = arith.index_castui %[[offset]] : index to i32
+    //CHECK: %[[c4_i32:.*]] = arith.constant 4 : i32
+    //CHECK: %[[mul:.*]] = arith.muli %[[offset_i32]], %[[c4_i32]] : i32
+    //CHECK: %[[add:.*]] = arith.addi %[[ptr_i32]], %[[mul]] : i32
+
     %0 = xegpu.create_mem_desc %subview : memref<32x32xf32, strided<[32, 1], offset: 1024>, 3> -> !xegpu.mem_desc<32x32xf32>
 
     //CHECK: %[[TID:.*]] = gpu.thread_id x
     //CHECK: %[[C1:.*]] = arith.constant 1 : index
     //CHECK: %[[MUL1:.*]] = arith.muli %[[TID]], %[[C1]] : index
-    //CHECK: %[[C4:.*]] = arith.constant 4 : i32
-    //CHECK: %[[MUL2:.*]] = arith.muli {{.*}}, %[[C4]] : i32
+    //CHECK: %[[MUL2:.*]] = arith.muli {{.*}}, {{.*}} : i32
     //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> f32
 
     %tid_x = gpu.thread_id x
-    
+
    %1 = xegpu.load_matrix %0[%c0, %tid_x]: !xegpu.mem_desc<32x32xf32>, index, index -> f32
 
     //CHECK: llvm.store {{.*}}, {{.*}} : f32, !llvm.ptr<3>
-    xegpu.store_matrix %1, %0[%c0, %tid_x]: f32, !xegpu.mem_desc<32x32xf32>, index, index 
+    xegpu.store_matrix %1, %0[%c0, %tid_x]: f32, !xegpu.mem_desc<32x32xf32>, index, index
     gpu.return %1: f32
   }
 
@@ -99,8 +106,6 @@ gpu.module @test_kernel [#xevm.target] {
 
   //CHECK-LABEL: load_store_matrix_blocked_nostride
   gpu.func @load_store_matrix_blocked_nostride(%arg0: memref<4096xi8, 3>) -> f16 {
-    //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %arg0 : memref<4096xi8, 3> -> index
-    //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32
     %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout>
 
     //CHECK: %[[tid_x:.*]] = gpu.thread_id x
@@ -178,7 +183,8 @@ gpu.module @test_kernel [#xevm.target] {
 
  //CHECK-LABEL: load_store_matrix_blocked_subgroupblockio
  gpu.func @load_store_matrix_blocked_subgroupblockio(%arg0: memref<4096xi8, 3>) -> vector<8xf16> {
-    //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %arg0 : memref<4096xi8, 3> -> index
+    //CHECK: %[[base_buffer:.*]], %[[offset:.*]], %[[sizes:.*]], %[[strides:.*]] = memref.extract_strided_metadata %arg0 : memref<4096xi8, 3> -> memref<i8, 3>, index, index, index
+    //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[base_buffer]] : memref<i8, 3> -> index
     //CHECK: %[[basePtrI32:.*]] = arith.index_castui %[[intptr]] : index to i32
     %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout>
 
@@ -206,7 +212,7 @@ gpu.module @test_kernel [#xevm.target] {
     //CHECK: %[[linearOffsetI64:.*]] = arith.index_castui %[[linearOffset]] : index to i32
     //CHECK: %[[c2:.*]] = arith.constant 2 : i32
     //CHECK: %[[byteOffset:.*]] = arith.muli %[[linearOffsetI64]], %[[c2]] : i32
-    //CHECK: %[[finalPtr:.*]] = arith.addi %[[basePtrI32]], %[[byteOffset]] : i32
+    //CHECK: %[[finalPtr:.*]] = arith.addi %[[basePtrI32:.*]], %[[byteOffset]] : i32
     //CHECK: %[[ptr:.*]] = llvm.inttoptr %[[finalPtr]] : i32 to !llvm.ptr<3>
     //CHECK: %[[loadedI16:.*]] = xevm.blockload %[[ptr]] : (!llvm.ptr<3>) -> vector<8xi16>
     //CHECK: %[[loaded:.*]] = vector.bitcast %[[loadedI16]] : vector<8xi16> to vector<8xf16>

From 2664dc1c1bc88e531db8c5aa40fe1bfb3004b9d0 Mon Sep 17 00:00:00 2001
From: Jianhui Li
Date: Wed, 3 Dec 2025 19:21:54 +0000
Subject: [PATCH 2/2] Add static offset support

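Use the memref's statically known layout when it is available: if
memrefTy.getStridesAndOffset() succeeds, the materialization takes the
aligned base pointer of the input memref directly and adds the constant
offset (scaled to bytes) instead of going through
memref.extract_strided_metadata; the metadata path is kept as a fallback
for layouts whose offset is not statically known. Shared (SLM) memrefs
are now recognized through the new xegpu::isSharedMemRef helper rather
than a hard-coded memory space of 3 when choosing i32 vs. i64 pointers.

For a shared-memory subview with a constant offset this is expected to
lower roughly as follows (illustrative sketch mirroring the updated
loadstore_matrix.mlir test; the SSA value names are only for readability
and are not produced verbatim by the pass):

    %intptr   = memref.extract_aligned_pointer_as_index %subview
        : memref<32x32xf32, strided<[32, 1], offset: 1024>, 3> -> index
    %offset   = arith.constant 1024 : index
    %ptr_i32  = arith.index_castui %intptr : index to i32
    %off_i32  = arith.index_castui %offset : index to i32
    %c4_i32   = arith.constant 4 : i32
    %byte_off = arith.muli %off_i32, %c4_i32 : i32
    %addr     = arith.addi %ptr_i32, %byte_off : i32
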
---
 .../mlir/Dialect/XeGPU/Utils/XeGPUUtils.h     |   3 +
 .../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp    |  52 +++--
 mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp   |  12 ++
 .../XeGPUToXeVM/create_nd_tdesc.mlir          |   8 +-
 .../XeGPUToXeVM/loadstore_matrix.mlir         | 188 +++++++++---------
 .../Conversion/XeGPUToXeVM/loadstore_nd.mlir  |   4 +-
 6 files changed, 146 insertions(+), 121 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
index 58092c3bb9ed2..b5978dc8d7b74 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -175,6 +175,9 @@ template <typename T>
 int getLargestDivisor(T dim, ArrayRef<T> candidates,
                       ArrayRef<T> candidateMultiples = {});
 
+/// Checks if the given MemRefType refers to shared memory.
+bool isSharedMemRef(const MemRefType &memrefTy);
+
 } // namespace xegpu
 } // namespace mlir
 
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index bafd1dc348e5b..a1c2745864dc6 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -991,9 +991,8 @@ struct ConvertXeGPUToXeVMPass
     });
 
     typeConverter.addConversion([&](MemRefType type) -> Type {
-      if (type.getMemorySpaceAsInt() == 3)
-        return IntegerType::get(&getContext(), 32);
-      return IntegerType::get(&getContext(), 64);
+      return IntegerType::get(&getContext(),
+                              (xegpu::isSharedMemRef(type) ? 32 : 64));
     });
 
     // LLVM type converter puts unrealized casts for the following cases:
@@ -1010,24 +1009,35 @@ struct ConvertXeGPUToXeVMPass
         unsigned rank = memrefTy.getRank();
         Type indexType = builder.getIndexType();
 
-        SmallVector<Type> resultTypes;
-        // Result types: [base_memref, offset, stride0, stride1, ..., strideN-1,
-        // size0, size1, ..., sizeN-1]
-        resultTypes.push_back(MemRefType::get(
-            {}, memrefTy.getElementType(), MemRefLayoutAttrInterface(),
-            memrefTy.getMemorySpace())); // base memref (unranked)
-        resultTypes.push_back(indexType); // offset
-        for (unsigned i = 0; i < rank; ++i)
-          resultTypes.push_back(indexType); // strides
-        for (unsigned i = 0; i < rank; ++i)
-          resultTypes.push_back(indexType); // sizes
-
-        auto meta = memref::ExtractStridedMetadataOp::create(
-            builder, loc, resultTypes, input);
-
-        auto addr = memref::ExtractAlignedPointerAsIndexOp::create(
-            builder, loc, meta.getBaseBuffer());
-        auto offset = meta.getOffset();
+        int64_t intOffsets;
+        SmallVector<int64_t> intStrides;
+        Value addr;
+        Value offset;
+        if (failed(memrefTy.getStridesAndOffset(intStrides, intOffsets))) {
+
+          // Result types: [base_memref, offset, stride0, stride1, ...,
+          // strideN-1, size0, size1, ..., sizeN-1]
+          SmallVector<Type> resultTypes{
+              MemRefType::get({}, memrefTy.getElementType(),
+                              MemRefLayoutAttrInterface(),
+                              memrefTy.getMemorySpace()),
+              indexType};
+          // strides + sizes
+          resultTypes.append(2 * rank, indexType);
+
+          auto meta = memref::ExtractStridedMetadataOp::create(
+              builder, loc, resultTypes, input);
+
+          addr = memref::ExtractAlignedPointerAsIndexOp::create(
+              builder, loc, meta.getBaseBuffer());
+          offset = meta.getOffset();
+
+        } else {
+          addr = memref::ExtractAlignedPointerAsIndexOp::create(builder, loc,
+                                                                input);
+          offset = arith::ConstantOp::create(builder, loc,
+                                             builder.getIndexAttr(intOffsets));
+        }
 
         auto addr_casted =
             arith::IndexCastUIOp::create(builder, loc, type, addr);
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 91432b1c11304..eecbb7b907e9f 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -580,3 +580,15 @@ template int xegpu::getLargestDivisor(int dim, ArrayRef<int> candidates,
 template int xegpu::getLargestDivisor(unsigned dim, ArrayRef<unsigned> candidates,
                                       ArrayRef<unsigned> candidateMultiples);
+
+/// Checks if the given MemRefType refers to shared memory.
+bool xegpu::isSharedMemRef(const MemRefType &memrefTy) {
+  Attribute attr = memrefTy.getMemorySpace();
+  if (!attr)
+    return false;
+  if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr))
+    return intAttr.getInt() == static_cast<int64_t>(xevm::AddrSpace::SHARED);
+  if (auto xevmSpace = llvm::dyn_cast<xevm::AddrSpaceAttr>(attr))
+    return xevmSpace.getValue() == xevm::AddrSpace::SHARED;
+  return gpu::GPUDialect::isWorkgroupMemoryAddressSpace(attr);
+}
diff --git a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
index 8b87b791c9fd3..242101955b900 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
@@ -8,7 +8,8 @@ gpu.module @create_nd_tdesc {
   gpu.func @create_nd_tdesc(%src: memref<16x32xf32, 1>, %ptr: ui64, %shape1: index, %shape2: index, %stride1: index, %stride2: index, %offset1: index, %offset2: index, %dyn: memref) kernel {
     // CHECK: %[[INTPTR_5:.*]] = memref.extract_aligned_pointer_as_index %[[DYN]] : memref -> index
-    // CHECK: %[[VAR23:.*]] = arith.index_castui %[[INTPTR_5]] : index to i64
+    // CHECK: %[[BASE_ADDR3:.*]] = arith.index_castui %[[INTPTR_5]] : index to i64
+
     // CHECK: %[[VAR0:.*]] = index.castu %[[ARG1]] : ui64 to index
     // CHECK: %[[BASE_ADDR:.*]] = arith.index_castui %[[VAR0]] : index to i64
     // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32>
@@ -39,7 +40,7 @@ gpu.module @create_nd_tdesc {
     // CHECK: %[[C16_I64:.*]] = arith.constant 16 : i64
     // CHECK: %[[SHAPE_H2:.*]] = arith.trunci %[[C16_I64]] : i64 to i32
     // CHECK: %[[VAR14:.*]] = vector.bitcast %[[CST_1]] : vector<8xi32> to vector<4xi64>
-    // CHECK: %[[VAR15:.*]] = vector.insert %[[BASE_ADDR2]], %[[VAR14]] [0] : i64 into vector<4xi64>
+    // CHECK: %[[VAR15:.*]] = vector.insert %[[BASE_ADDR2_OFFSET:.*]], %[[VAR14]] [0] : i64 into vector<4xi64>
     // CHECK: %[[VAR16:.*]] = vector.bitcast %[[VAR15]] : vector<4xi64> to vector<8xi32>
     // CHECK: %[[VAR17:.*]] = vector.insert %[[SHAPE_W2]], %[[VAR16]] [2] : i32 into vector<8xi32>
     // CHECK: %[[VAR18:.*]] = vector.insert %[[SHAPE_H2]], %[[VAR17]] [3] : i32 into vector<8xi32>
@@ -53,13 +54,14 @@ gpu.module @create_nd_tdesc {
     %size_x = arith.constant 64 : index
     // CHECK: %[[C16:.*]] = arith.constant 16 : index
     %BLOCK_DMODEL = arith.constant 16 : index
+    // CHECK: %[[CST_4:.*]] = arith.constant dense<0> : vector<8xi32>
     // CHECK: %[[C0_I32_6:.*]] = arith.constant 0 : i32
     // CHECK: %[[C0_I32_7:.*]] = arith.constant 0 : i32
     // CHECK: %[[VAR21:.*]] = arith.index_cast %[[C16]] : index to i32
     // CHECK: %[[VAR22:.*]] = arith.index_cast %[[C64]] : index to i32
     // CHECK: %[[VAR24:.*]] = vector.bitcast %[[CST_4]] : vector<8xi32> to vector<4xi64>
-    // CHECK: %[[VAR25:.*]] = vector.insert %[[VAR23]], %[[VAR24]] [0] : i64 into vector<4xi64>
+    // CHECK: %[[VAR25:.*]] = vector.insert %[[BASE_ADDR3_OFFSET:.*]], %[[VAR24]] [0] : i64 into vector<4xi64>
     // CHECK: %[[VAR26:.*]] = vector.bitcast %[[VAR25]] : vector<4xi64> to vector<8xi32>
     // CHECK: %[[VAR27:.*]] = vector.insert %[[VAR21]], %[[VAR26]] [2] : i32 into vector<8xi32>
     // CHECK: %[[VAR28:.*]] = vector.insert %[[VAR22]], %[[VAR27]] [3] : i32 into vector<8xi32>
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
index aba73b80f0439..179fd397d7074 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
@@ -33,21 +33,20 @@ gpu.module @test_kernel [#xevm.target] {
 
     %subview = memref.subview %view[32, 0] [32, 32] [1, 1] : memref<64x32xf32, 3> to memref<32x32xf32, strided<[32, 1], offset: 1024>, 3>
 
-    //CHECK: %[[base_buffer:.*]], %[[offset:.*]], %[[sizes:.*]]:2, %[[strides:.*]]:2 = memref.extract_strided_metadata %{{.*}} : memref<32x32xf32, strided<[32, 1], offset: 1024>, 3> -> memref<f32, 3>, index, index, index, index, index
-    //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[base_buffer]] : memref<f32, 3> -> index
-    //CHECK: %[[ptr_i32:.*]] = arith.index_castui %[[intptr]] : index to i32
-    //CHECK: %[[offset_i32:.*]] = arith.index_castui %[[offset]] : index to i32
-    //CHECK: %[[c4_i32:.*]] = arith.constant 4 : i32
-    //CHECK: %[[mul:.*]] = arith.muli %[[offset_i32]], %[[c4_i32]] : i32
-    //CHECK: %[[add:.*]] = arith.addi %[[ptr_i32]], %[[mul]] : i32
+    //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[base_buffer:.*]] : memref<32x32xf32, strided<[32, 1], offset: 1024>, 3> -> index
+    //CHECK-DAG: %[[ptr_i32:.*]] = arith.index_castui %[[intptr]] : index to i32
+    //CHECK-DAG: %[[offset_i32:.*]] = arith.index_castui %[[offset:.*]] : index to i32
+    //CHECK-DAG: %[[c4_i32:.*]] = arith.constant 4 : i32
+    //CHECK-DAG: %[[mul:.*]] = arith.muli %[[offset_i32]], %[[c4_i32]] : i32
+    //CHECK-DAG: %[[add:.*]] = arith.addi %[[ptr_i32]], %[[mul]] : i32
 
     %0 = xegpu.create_mem_desc %subview : memref<32x32xf32, strided<[32, 1], offset: 1024>, 3> -> !xegpu.mem_desc<32x32xf32>
 
-    //CHECK: %[[TID:.*]] = gpu.thread_id x
-    //CHECK: %[[C1:.*]] = arith.constant 1 : index
-    //CHECK: %[[MUL1:.*]] = arith.muli %[[TID]], %[[C1]] : index
-    //CHECK: %[[MUL2:.*]] = arith.muli {{.*}}, {{.*}} : i32
-    //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> f32
+    //CHECK-DAG: %[[TID:.*]] = gpu.thread_id x
+    //CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+    //CHECK-DAG: %[[MUL1:.*]] = arith.muli %[[TID]], %[[C1]] : index
+    //CHECK-DAG: %[[MUL2:.*]] = arith.muli {{.*}}, {{.*}} : i32
+    //CHECK-DAG: llvm.load {{.*}} : !llvm.ptr<3> -> f32
 
     %tid_x = gpu.thread_id x
 
@@ -67,25 +66,25 @@ gpu.module @test_kernel [#xevm.target] {
   gpu.func @load_store_matrix_blocked_strided(%arg0: memref<4096xi8, 3>) -> f16 {
     %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout>
 
-    //CHECK: %[[tid_x:.*]] = gpu.thread_id x
-    //CHECK: %[[c13:.*]] = arith.constant 13 : index
-    //CHECK: %[[c16:.*]] = arith.constant 16 : index
-    //CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c13]], %[[c16]] : index
-    //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c13]], %[[c16]] : index
-    //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
-    //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
-    //CHECK: %[[c0:.*]] = arith.constant 0 : index
-    //CHECK: %[[c256:.*]] = arith.constant 256 : index
-    //CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c256]] : index
-    //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
-    //CHECK: %[[c512:.*]] = arith.constant 512 : index
-    //CHECK: %[[mul1:.*]] = arith.muli %[[offsety_0]], %[[c512]] : index
-    //CHECK: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index
-    //CHECK: %[[c1:.*]] = arith.constant 1 : index
-    //CHECK: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c1]] : index
-    //CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index
-    //CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c16]] : index
-    //CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
+    //CHECK-DAG: %[[tid_x:.*]] = gpu.thread_id x
+    //CHECK-DAG: %[[c13:.*]] = arith.constant 13 : index
+    //CHECK-DAG: %[[c16:.*]] = arith.constant 16 : index
+    //CHECK-DAG: %[[offsetx_0:.*]] = arith.divsi %[[c13]], %[[c16]] : index
+    //CHECK-DAG: %[[offsetx_1:.*]] = arith.remsi %[[c13]], %[[c16]] : index
+    //CHECK-DAG: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
+    //CHECK-DAG: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
+    //CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+    //CHECK-DAG: %[[c256:.*]] = arith.constant 256 : index
+    //CHECK-DAG: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c256]] : index
+    //CHECK-DAG: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
+    //CHECK-DAG: %[[c512:.*]] = arith.constant 512 : index
+    //CHECK-DAG: %[[mul1:.*]] = arith.muli %[[offsety_0]], %[[c512]] : index
+    //CHECK-DAG: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index
+    //CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
+    //CHECK-DAG: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c1]] : index
+    //CHECK-DAG: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index
+    //CHECK-DAG: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c16]] : index
+    //CHECK-DAG: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
 
     //CHECK: %[[loaded:.*]] = llvm.load {{.*}}: !llvm.ptr<3> -> f16
 
@@ -108,29 +107,29 @@ gpu.module @test_kernel [#xevm.target] {
     %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout>
 
-    //CHECK: %[[tid_x:.*]] = gpu.thread_id x
-    //CHECK: %[[c19:.*]] = arith.constant 19 : index
+    //CHECK-DAG: %[[tid_x:.*]] = gpu.thread_id x
+    //CHECK-DAG: %[[c19:.*]] = arith.constant 19 : index
     %tid_x = gpu.thread_id x
     %c19 = arith.constant 19: index
 
-    //CHECK: %[[c16:.*]] = arith.constant 16 : index
-    //CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c19]], %[[c16]] : index
-    //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c19]], %[[c16]] : index
-    //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
-    //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
-    //CHECK: %[[c0:.*]] = arith.constant 0 : index
-    //CHECK: %[[c1024:.*]] = arith.constant 1024 : index
-    //CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c1024]] : index
-    //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
-    //CHECK: %[[c256:.*]] = arith.constant 256 : index
-    //CHECK: %[[mul1:.*]] = arith.muli %[[offsety_0]], %[[c256]] : index
-    //CHECK: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index
-    //CHECK: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c16]] : index
-    //CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index
-    //CHECK: %[[c1:.*]] = arith.constant 1 : index
-    //CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c1]] : index
-    //CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
-    //CHECK: %[[loaded:.*]] = llvm.load {{.*}} : !llvm.ptr<3> -> f16
+    //CHECK-DAG: %[[c16:.*]] = arith.constant 16 : index
+    //CHECK-DAG: %[[offsetx_0:.*]] = arith.divsi %[[c19]], %[[c16]] : index
+    //CHECK-DAG: %[[offsetx_1:.*]] = arith.remsi %[[c19]], %[[c16]] : index
+    //CHECK-DAG: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
+    //CHECK-DAG: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
+    //CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+    //CHECK-DAG: %[[c1024:.*]] = arith.constant 1024 : index
+    //CHECK-DAG: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c1024]] : index
+    //CHECK-DAG: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
+    //CHECK-DAG: %[[c256:.*]] = arith.constant 256 : index
+    //CHECK-DAG: %[[mul1:.*]] = arith.muli %[[offsety_0]], %[[c256]] : index
+    //CHECK-DAG: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index
+    //CHECK-DAG: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c16]] : index
+    //CHECK-DAG: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index
+    //CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
+    //CHECK-DAG: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c1]] : index
+    //CHECK-DAG: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
+    //CHECK-DAG: %[[loaded:.*]] = llvm.load {{.*}} : !llvm.ptr<3> -> f16
 
     %1 = xegpu.load_matrix %0[%c19, %tid_x]: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout>, index, index -> f16
 
     //CHECK: llvm.store %[[loaded]], {{.*}} : f16, !llvm.ptr<3>
@@ -146,24 +145,24 @@ gpu.module @test_kernel [#xevm.target] {
   gpu.func @load_store_matrix_blocked_strided_return_vector(%arg0: memref<4096xi8, 3>) -> vector<8xf16> {
     %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout>
 
-    //CHECK: %[[tid_x:.*]] = gpu.thread_id x
-    //CHECK: %[[c16:.*]] = arith.constant 16 : index
-    //CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c16]], %[[c16]] : index
-    //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c16]], %[[c16]] : index
-    //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
-    //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
-    //CHECK: %[[c0:.*]] = arith.constant 0 : index
-    //CHECK: %[[c256:.*]] = arith.constant 256 : index
-    //CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c256]] : index
-    //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
-    //CHECK: %[[c512:.*]] = arith.constant 512 : index
-    //CHECK: %[[mul1:.*]] = arith.muli %[[offsety_0]], %[[c512]] : index
-    //CHECK: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index
-    //CHECK: %[[c1:.*]] = arith.constant 1 : index
-    //CHECK: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c1]] : index
-    //CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index
-    //CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c16]] : index
-    //CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
+    //CHECK-DAG: %[[tid_x:.*]] = gpu.thread_id x
+    //CHECK-DAG: %[[c16:.*]] = arith.constant 16 : index
+    //CHECK-DAG: %[[offsetx_0:.*]] = arith.divsi %[[c16]], %[[c16]] : index
+    //CHECK-DAG: %[[offsetx_1:.*]] = arith.remsi %[[c16]], %[[c16]] : index
+    //CHECK-DAG: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
+    //CHECK-DAG: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
+    //CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+    //CHECK-DAG: %[[c256:.*]] = arith.constant 256 : index
+    //CHECK-DAG: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c256]] : index
+    //CHECK-DAG: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
+    //CHECK-DAG: %[[c512:.*]] = arith.constant 512 : index
+    //CHECK-DAG: %[[mul1:.*]] = arith.muli %[[offsety_0]], %[[c512]] : index
+    //CHECK-DAG: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index
+    //CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
+    //CHECK-DAG: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c1]] : index
+    //CHECK-DAG: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index
+    //CHECK-DAG: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c16]] : index
+    //CHECK-DAG: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
 
     //CHECK: %[[loaded:.*]] = llvm.load {{.*}}: !llvm.ptr<3> -> vector<8xf16>
@@ -183,9 +182,9 @@ gpu.module @test_kernel [#xevm.target] {
 
   //CHECK-LABEL: load_store_matrix_blocked_subgroupblockio
   gpu.func @load_store_matrix_blocked_subgroupblockio(%arg0: memref<4096xi8, 3>) -> vector<8xf16> {
-    //CHECK: %[[base_buffer:.*]], %[[offset:.*]], %[[sizes:.*]], %[[strides:.*]] = memref.extract_strided_metadata %arg0 : memref<4096xi8, 3> -> memref<i8, 3>, index, index, index
-    //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[base_buffer]] : memref<i8, 3> -> index
-    //CHECK: %[[basePtrI32:.*]] = arith.index_castui %[[intptr]] : index to i32
+    //CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+    //CHECK-DAG: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[base_buffer:.*]] : memref<4096xi8, 3> -> index
+    //CHECK-DAG: %[[basePtrI32:.*]] = arith.index_castui %[[intptr]] : index to i32
     %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout>
 
     //CHECK: %[[c16:.*]] = arith.constant 16 : index
@@ -193,29 +192,28 @@ gpu.module @test_kernel [#xevm.target] {
     //CHECK: %[[c48:.*]] = arith.constant 48 : index
     %c16 = arith.constant 16 : index
     %c48 = arith.constant 48 : index
 
-    //CHECK: %[[offset0:.*]] = arith.divsi %[[c16]], %[[c16]] : index
-    //CHECK: %[[offset1:.*]] = arith.remsi %[[c16]], %[[c16]] : index
-    //CHECK: %[[offset2:.*]] = arith.divsi %[[c48]], %[[c16]] : index
-    //CHECK: %[[offset3:.*]] = arith.remsi %[[c48]], %[[c16]] : index
-    //CHECK: %[[c0:.*]] = arith.constant 0 : index
-    //CHECK: %[[c1024:.*]] = arith.constant 1024 : index
-    //CHECK: %[[mul0:.*]] = arith.muli %[[offset0]], %[[c1024]] : index
-    //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
-    //CHECK: %[[c256:.*]] = arith.constant 256 : index
-    //CHECK: %[[mul1:.*]] = arith.muli %[[offset2]], %[[c256]] : index
-    //CHECK: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index
-    //CHECK: %[[mul2:.*]] = arith.muli %[[offset1]], %[[c16]] : index
-    //CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index
-    //CHECK: %[[c1:.*]] = arith.constant 1 : index
-    //CHECK: %[[mul3:.*]] = arith.muli %[[offset3]], %[[c1]] : index
-    //CHECK: %[[linearOffset:.*]] = arith.addi %[[mul3]], %[[add2]] : index
-    //CHECK: %[[linearOffsetI64:.*]] = arith.index_castui %[[linearOffset]] : index to i32
-    //CHECK: %[[c2:.*]] = arith.constant 2 : i32
-    //CHECK: %[[byteOffset:.*]] = arith.muli %[[linearOffsetI64]], %[[c2]] : i32
-    //CHECK: %[[finalPtr:.*]] = arith.addi %[[basePtrI32:.*]], %[[byteOffset]] : i32
-    //CHECK: %[[ptr:.*]] = llvm.inttoptr %[[finalPtr]] : i32 to !llvm.ptr<3>
-    //CHECK: %[[loadedI16:.*]] = xevm.blockload %[[ptr]] : (!llvm.ptr<3>) -> vector<8xi16>
-    //CHECK: %[[loaded:.*]] = vector.bitcast %[[loadedI16]] : vector<8xi16> to vector<8xf16>
+    //CHECK-DAG: %[[offset0:.*]] = arith.divsi %[[c16]], %[[c16]] : index
+    //CHECK-DAG: %[[offset1:.*]] = arith.remsi %[[c16]], %[[c16]] : index
+    //CHECK-DAG: %[[offset2:.*]] = arith.divsi %[[c48]], %[[c16]] : index
+    //CHECK-DAG: %[[offset3:.*]] = arith.remsi %[[c48]], %[[c16]] : index
+    //CHECK-DAG: %[[c1024:.*]] = arith.constant 1024 : index
+    //CHECK-DAG: %[[mul0:.*]] = arith.muli %[[offset0]], %[[c1024]] : index
+    //CHECK-DAG: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
+    //CHECK-DAG: %[[c256:.*]] = arith.constant 256 : index
+    //CHECK-DAG: %[[mul1:.*]] = arith.muli %[[offset2]], %[[c256]] : index
+    //CHECK-DAG: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index
+    //CHECK-DAG: %[[mul2:.*]] = arith.muli %[[offset1]], %[[c16]] : index
+    //CHECK-DAG: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index
+    //CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
+    //CHECK-DAG: %[[mul3:.*]] = arith.muli %[[offset3]], %[[c1]] : index
+    //CHECK-DAG: %[[linearOffset:.*]] = arith.addi %[[mul3]], %[[add2]] : index
+    //CHECK-DAG: %[[linearOffsetI64:.*]] = arith.index_castui %[[linearOffset]] : index to i32
+    //CHECK-DAG: %[[c2:.*]] = arith.constant 2 : i32
+    //CHECK-DAG: %[[byteOffset:.*]] = arith.muli %[[linearOffsetI64]], %[[c2]] : i32
+    //CHECK-DAG: %[[finalPtr:.*]] = arith.addi %[[basePtrI32:.*]], %[[byteOffset]] : i32
+    //CHECK-DAG: %[[ptr:.*]] = llvm.inttoptr %[[finalPtr]] : i32 to !llvm.ptr<3>
+    //CHECK-DAG: %[[loadedI16:.*]] = xevm.blockload %[[ptr]] : (!llvm.ptr<3>) -> vector<8xi16>
+    //CHECK-DAG: %[[loaded:.*]] = vector.bitcast %[[loadedI16]] : vector<8xi16> to vector<8xf16>
 
     %1 = xegpu.load_matrix %0[%c16, %c48] {subgroup_block_io}: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout>, index, index -> vector<8xf16>
 
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir
index afeae8be24b72..30fbb66ec9e58 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir
@@ -14,7 +14,7 @@ gpu.module @load_store_check {
     // CHECK: %[[INTPTR_1:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST_0]] : memref<8x16xf32> -> index
     // CHECK: %[[ST_PTR_AS_I64:.*]] = arith.index_castui %[[INTPTR_1]] : index to i64
     // CHECK: %[[LD_CREATE_DESC_I64:.*]] = vector.bitcast {{.*}} : vector<8xi32> to vector<4xi64>
-    // CHECK: %[[LD_DESC_0:.*]] = vector.insert %[[LD_PTR_AS_I64]], %[[LD_CREATE_DESC_I64]] [0] : i64 into vector<4xi64>
+    // CHECK: %[[LD_DESC_0:.*]] = vector.insert %[[LD_PTR_AS_I64_OFFSET:.*]], %[[LD_CREATE_DESC_I64]] [0] : i64 into vector<4xi64>
     // CHECK: %[[LD_DESC_1:.*]] = vector.bitcast %[[LD_DESC_0]] : vector<4xi64> to vector<8xi32>
     // CHECK: %[[LD_DESC_2:.*]] = vector.insert {{.*}}, %[[LD_DESC_1]] [2] : i32 into vector<8xi32>
     // CHECK: %[[LD_DESC_3:.*]] = vector.insert {{.*}}, %[[LD_DESC_2]] [3] : i32 into vector<8xi32>
@@ -50,7 +50,7 @@ gpu.module @load_store_check {
     %loaded_modified = vector.insert %tid_x_f32, %loaded[0] : f32 into vector<8xf32>
 
     // CHECK: %[[CREATE_DESC_I64:.*]] = vector.bitcast {{.*}} : vector<8xi32> to vector<4xi64>
-    // CHECK: %[[DESC_0:.*]] = vector.insert %[[ST_PTR_AS_I64]], %[[CREATE_DESC_I64]] [0] : i64 into vector<4xi64>
+    // CHECK: %[[DESC_0:.*]] = vector.insert %[[ST_PTR_AS_I64_OFFSET:.*]], %[[CREATE_DESC_I64]] [0] : i64 into vector<4xi64>
     // CHECK: %[[DESC_1:.*]] = vector.bitcast %[[DESC_0]] : vector<4xi64> to vector<8xi32>
     // CHECK: %[[DESC_2:.*]] = vector.insert {{.*}}, %[[DESC_1]] [2] : i32 into vector<8xi32>
     // CHECK: %[[DESC_3:.*]] = vector.insert {{.*}}, %[[DESC_2]] [3] : i32 into vector<8xi32>