@Jianhui-Li
Contributor

During the XeGPU-to-XeVM type conversion, a memref is lowered to its base address. This PR extends the conversion to correctly handle memrefs that include an offset, such as those generated by memref.subview.
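
To make the address arithmetic concrete, here is a minimal sketch in plain C++ (no MLIR dependencies) of what the lowered value represents: the aligned base pointer plus the layout offset scaled by the element size in bytes. The `MemRefView` struct and `loweredAddress` function are hypothetical names used only for illustration, and the sketch uses a 64-bit address throughout, whereas the pass emits i32 for shared memory.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for the parts of a strided memref that matter here.
struct MemRefView {
  uint64_t alignedBase;     // address of the aligned base pointer
  int64_t offsetInElements; // layout offset, counted in elements
  unsigned elementBitWidth; // e.g. 32 for f32
};

// Mirrors the arithmetic in the patch: base + offset * (bitwidth / 8).
inline uint64_t loweredAddress(const MemRefView &view) {
  // Assumption of this sketch: elements occupy a whole number of bytes.
  assert(view.elementBitWidth % 8 == 0 && "sketch assumes byte-sized elements");
  uint64_t bytesPerElement = view.elementBitWidth / 8;
  return view.alignedBase +
         static_cast<uint64_t>(view.offsetInElements) * bytesPerElement;
}
```

For the subview in loadstore_matrix.mlir (offset: 1024, f32 elements), `loweredAddress(MemRefView{base, 1024, 32})` yields base + 4096, consistent with the multiply-by-4 in the new CHECK lines.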

@llvmbot
Member

llvmbot commented Dec 3, 2025

@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir

Author: Jianhui Li (Jianhui-Li)

Changes


Patch is 25.66 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/170541.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h (+3)
  • (modified) mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp (+51-8)
  • (modified) mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp (+12)
  • (modified) mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir (+5-3)
  • (modified) mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir (+96-92)
  • (modified) mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir (+2-2)
diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
index 58092c3bb9ed2..b5978dc8d7b74 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -175,6 +175,9 @@ template <typename T>
 int getLargestDivisor(T dim, ArrayRef<T> candidates,
                       ArrayRef<T> candidateMultiples = {});
 
+/// Checks if the given MemRefType refers to shared memory.
+bool isSharedMemRef(const MemRefType &memrefTy);
+
 } // namespace xegpu
 
 } // namespace mlir
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 7f1ec17ce0ae8..a1c2745864dc6 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -991,15 +991,14 @@ struct ConvertXeGPUToXeVMPass
     });
 
     typeConverter.addConversion([&](MemRefType type) -> Type {
-      if (type.getMemorySpaceAsInt() == 3)
-        return IntegerType::get(&getContext(), 32);
-      return IntegerType::get(&getContext(), 64);
+      return IntegerType::get(&getContext(),
+                              (xegpu::isSharedMemRef(type) ? 32 : 64));
     });
 
     // LLVM type converter puts unrealized casts for the following cases:
     // add materialization casts to handle them.
 
-    // Materialization to convert memref to i64
+    // Materialization to convert memref to i64 or i32 depending on global/SLM
     auto memrefMaterializationCast = [](OpBuilder &builder, Type type,
                                         ValueRange inputs,
                                         Location loc) -> Value {
@@ -1007,11 +1006,55 @@ struct ConvertXeGPUToXeVMPass
         return {};
       auto input = inputs.front();
       if (auto memrefTy = dyn_cast<MemRefType>(input.getType())) {
+        unsigned rank = memrefTy.getRank();
+        Type indexType = builder.getIndexType();
 
-        Value addr =
-            memref::ExtractAlignedPointerAsIndexOp::create(builder, loc, input);
-        return arith::IndexCastUIOp::create(builder, loc, type, addr)
-            .getResult();
+        int64_t intOffsets;
+        SmallVector<int64_t> intStrides;
+        Value addr;
+        Value offset;
+        if (failed(memrefTy.getStridesAndOffset(intStrides, intOffsets))) {
+
+          // Result types: [base_memref, offset, stride0, stride1, ...,
+          // strideN-1, size0, size1, ..., sizeN-1]
+          SmallVector<Type> resultTypes{
+              MemRefType::get({}, memrefTy.getElementType(),
+                              MemRefLayoutAttrInterface(),
+                              memrefTy.getMemorySpace()),
+              indexType};
+          // strides + sizes
+          resultTypes.append(2 * rank, indexType);
+
+          auto meta = memref::ExtractStridedMetadataOp::create(
+              builder, loc, resultTypes, input);
+
+          addr = memref::ExtractAlignedPointerAsIndexOp::create(
+              builder, loc, meta.getBaseBuffer());
+          offset = meta.getOffset();
+
+        } else {
+          addr = memref::ExtractAlignedPointerAsIndexOp::create(builder, loc,
+                                                                input);
+          offset = arith::ConstantOp::create(builder, loc,
+                                             builder.getIndexAttr(intOffsets));
+        }
+
+        auto addr_casted =
+            arith::IndexCastUIOp::create(builder, loc, type, addr);
+        auto offset_casted =
+            arith::IndexCastUIOp::create(builder, loc, type, offset);
+
+        // Compute the final address: base address + byte offset
+        auto byte_size = arith::ConstantOp::create(
+            builder, loc, type,
+            builder.getIntegerAttr(type,
+                                   memrefTy.getElementTypeBitWidth() / 8));
+        auto byte_offset =
+            arith::MulIOp::create(builder, loc, offset_casted, byte_size);
+        auto addr_with_offset =
+            arith::AddIOp::create(builder, loc, addr_casted, byte_offset);
+
+        return addr_with_offset.getResult();
       }
       return {};
     };
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 91432b1c11304..eecbb7b907e9f 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -580,3 +580,15 @@ template int xegpu::getLargestDivisor<int>(int dim, ArrayRef<int> candidates,
 template int
 xegpu::getLargestDivisor<unsigned>(unsigned dim, ArrayRef<unsigned> candidates,
                                    ArrayRef<unsigned> candidateMultiples);
+
+/// Checks if the given MemRefType refers to shared memory.
+bool xegpu::isSharedMemRef(const MemRefType &memrefTy) {
+  Attribute attr = memrefTy.getMemorySpace();
+  if (!attr)
+    return false;
+  if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr))
+    return intAttr.getInt() == static_cast<int>(xevm::AddrSpace::SHARED);
+  if (auto xevmSpace = llvm::dyn_cast<xevm::AddrSpaceAttr>(attr))
+    return xevmSpace.getValue() == xevm::AddrSpace::SHARED;
+  return gpu::GPUDialect::isWorkgroupMemoryAddressSpace(attr);
+}
diff --git a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
index 8b87b791c9fd3..242101955b900 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
@@ -8,7 +8,8 @@ gpu.module @create_nd_tdesc {
   gpu.func @create_nd_tdesc(%src: memref<16x32xf32, 1>, %ptr: ui64, %shape1: index, %shape2: index,
   %stride1: index, %stride2: index, %offset1: index, %offset2: index, %dyn: memref<?x?xf16>) kernel {
         // CHECK: %[[INTPTR_5:.*]] = memref.extract_aligned_pointer_as_index %[[DYN]] : memref<?x?xf16> -> index
-        // CHECK: %[[VAR23:.*]] = arith.index_castui %[[INTPTR_5]] : index to i64
+        // CHECK: %[[BASE_ADDR3:.*]] = arith.index_castui %[[INTPTR_5]] : index to i64    
+        
         // CHECK: %[[VAR0:.*]] = index.castu %[[ARG1]] : ui64 to index
         // CHECK: %[[BASE_ADDR:.*]] = arith.index_castui %[[VAR0]] : index to i64
         // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32>
@@ -39,7 +40,7 @@ gpu.module @create_nd_tdesc {
         // CHECK: %[[C16_I64:.*]] = arith.constant 16 : i64
         // CHECK: %[[SHAPE_H2:.*]] = arith.trunci %[[C16_I64]] : i64 to i32
         // CHECK: %[[VAR14:.*]] = vector.bitcast %[[CST_1]] : vector<8xi32> to vector<4xi64>
-        // CHECK: %[[VAR15:.*]] = vector.insert %[[BASE_ADDR2]], %[[VAR14]] [0] : i64 into vector<4xi64>
+        // CHECK: %[[VAR15:.*]] = vector.insert %[[BASE_ADDR2_OFFSET:.*]], %[[VAR14]] [0] : i64 into vector<4xi64>
         // CHECK: %[[VAR16:.*]] = vector.bitcast %[[VAR15]] : vector<4xi64> to vector<8xi32>
         // CHECK: %[[VAR17:.*]] = vector.insert %[[SHAPE_W2]], %[[VAR16]] [2] : i32 into vector<8xi32>
         // CHECK: %[[VAR18:.*]] = vector.insert %[[SHAPE_H2]], %[[VAR17]] [3] : i32 into vector<8xi32>
@@ -53,13 +54,14 @@ gpu.module @create_nd_tdesc {
         %size_x = arith.constant 64 : index
         // CHECK: %[[C16:.*]] = arith.constant 16 : index
         %BLOCK_DMODEL = arith.constant 16 : index
+   
         // CHECK: %[[CST_4:.*]] = arith.constant dense<0> : vector<8xi32>
         // CHECK: %[[C0_I32_6:.*]] = arith.constant 0 : i32
         // CHECK: %[[C0_I32_7:.*]] = arith.constant 0 : i32
         // CHECK: %[[VAR21:.*]] = arith.index_cast %[[C16]] : index to i32
         // CHECK: %[[VAR22:.*]] = arith.index_cast %[[C64]] : index to i32
         // CHECK: %[[VAR24:.*]] = vector.bitcast %[[CST_4]] : vector<8xi32> to vector<4xi64>
-        // CHECK: %[[VAR25:.*]] = vector.insert %[[VAR23]], %[[VAR24]] [0] : i64 into vector<4xi64>
+        // CHECK: %[[VAR25:.*]] = vector.insert %[[BASE_ADDR3_OFFSET:.*]], %[[VAR24]] [0] : i64 into vector<4xi64>
         // CHECK: %[[VAR26:.*]] = vector.bitcast %[[VAR25]] : vector<4xi64> to vector<8xi32>
         // CHECK: %[[VAR27:.*]] = vector.insert %[[VAR21]], %[[VAR26]] [2] : i32 into vector<8xi32>
         // CHECK: %[[VAR28:.*]] = vector.insert %[[VAR22]], %[[VAR27]] [3] : i32 into vector<8xi32>
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
index ac95a1a5707ea..179fd397d7074 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
@@ -33,22 +33,28 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
     
     %subview = memref.subview %view[32, 0] [32, 32] [1, 1] : memref<64x32xf32, 3> to memref<32x32xf32, strided<[32, 1], offset: 1024>, 3>
 
+    //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[base_buffer:.*]] : memref<32x32xf32, strided<[32, 1], offset: 1024>, 3> -> index
+    //CHECK-DAG: %[[ptr_i32:.*]] = arith.index_castui %[[intptr]] : index to i32
+    //CHECK-DAG: %[[offset_i32:.*]] = arith.index_castui %[[offset:.*]] : index to i32
+    //CHECK-DAG: %[[c4_i32:.*]] = arith.constant 4 : i32
+    //CHECK-DAG: %[[mul:.*]] = arith.muli %[[offset_i32]], %[[c4_i32]] : i32
+    //CHECK-DAG: %[[add:.*]] = arith.addi %[[ptr_i32]], %[[mul]] : i32
+
     %0 = xegpu.create_mem_desc %subview : memref<32x32xf32, strided<[32, 1], offset: 1024>, 3> -> !xegpu.mem_desc<32x32xf32>
 
-    //CHECK: %[[TID:.*]] = gpu.thread_id x
-    //CHECK: %[[C1:.*]] = arith.constant 1 : index
-    //CHECK: %[[MUL1:.*]] = arith.muli %[[TID]], %[[C1]] : index
-    //CHECK: %[[C4:.*]] = arith.constant 4 : i32
-    //CHECK: %[[MUL2:.*]] = arith.muli {{.*}}, %[[C4]] : i32
-    //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> f32
+    //CHECK-DAG: %[[TID:.*]] = gpu.thread_id x
+    //CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+    //CHECK-DAG: %[[MUL1:.*]] = arith.muli %[[TID]], %[[C1]] : index
+     //CHECK-DAG: %[[MUL2:.*]] = arith.muli {{.*}}, {{.*}} : i32
+    //CHECK-DAG: llvm.load {{.*}} : !llvm.ptr<3> -> f32
 
     %tid_x = gpu.thread_id x
- 
+  
     %1 = xegpu.load_matrix %0[%c0, %tid_x]: !xegpu.mem_desc<32x32xf32>, index, index -> f32
 
     //CHECK: llvm.store {{.*}}, {{.*}} : f32, !llvm.ptr<3>
 
-     xegpu.store_matrix %1, %0[%c0, %tid_x]: f32, !xegpu.mem_desc<32x32xf32>, index, index
+    xegpu.store_matrix %1, %0[%c0, %tid_x]: f32, !xegpu.mem_desc<32x32xf32>, index, index
 
     gpu.return %1: f32
   }
@@ -60,25 +66,25 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
   gpu.func @load_store_matrix_blocked_strided(%arg0: memref<4096xi8, 3>) -> f16 {
     %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>
 
-    //CHECK: %[[tid_x:.*]] = gpu.thread_id x
-    //CHECK: %[[c13:.*]] = arith.constant 13 : index
-    //CHECK: %[[c16:.*]] = arith.constant 16 : index
-    //CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c13]], %[[c16]] : index
-    //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c13]], %[[c16]] : index
-    //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
-    //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
-    //CHECK: %[[c0:.*]] = arith.constant 0 : index
-    //CHECK: %[[c256:.*]] = arith.constant 256 : index
-    //CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c256]] : index
-    //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
-    //CHECK: %[[c512:.*]] = arith.constant 512 : index
-    //CHECK: %[[mul1:.*]] = arith.muli %[[offsety_0]], %[[c512]] : index
-    //CHECK: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index
-    //CHECK: %[[c1:.*]] = arith.constant 1 : index
-    //CHECK: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c1]] : index
-    //CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index
-    //CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c16]] : index
-    //CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
+    //CHECK-DAG: %[[tid_x:.*]] = gpu.thread_id x
+    //CHECK-DAG: %[[c13:.*]] = arith.constant 13 : index
+    //CHECK-DAG: %[[c16:.*]] = arith.constant 16 : index
+    //CHECK-DAG: %[[offsetx_0:.*]] = arith.divsi %[[c13]], %[[c16]] : index
+    //CHECK-DAG: %[[offsetx_1:.*]] = arith.remsi %[[c13]], %[[c16]] : index
+    //CHECK-DAG: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
+    //CHECK-DAG: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
+    //CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+    //CHECK-DAG: %[[c256:.*]] = arith.constant 256 : index
+    //CHECK-DAG: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c256]] : index
+    //CHECK-DAG: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
+    //CHECK-DAG: %[[c512:.*]] = arith.constant 512 : index
+    //CHECK-DAG: %[[mul1:.*]] = arith.muli %[[offsety_0]], %[[c512]] : index
+    //CHECK-DAG: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index
+    //CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
+    //CHECK-DAG: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c1]] : index
+    //CHECK-DAG: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index
+    //CHECK-DAG: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c16]] : index
+    //CHECK-DAG: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
 
     //CHECK: %[[loaded:.*]] = llvm.load {{.*}}: !llvm.ptr<3> -> f16
  
@@ -99,33 +105,31 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
   //CHECK-LABEL: load_store_matrix_blocked_nostride
   gpu.func @load_store_matrix_blocked_nostride(%arg0: memref<4096xi8, 3>) -> f16 {
 
-    //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %arg0 : memref<4096xi8, 3> -> index
-    //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32
     %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>
     
-    //CHECK: %[[tid_x:.*]] = gpu.thread_id x
-    //CHECK: %[[c19:.*]] = arith.constant 19 : index
+    //CHECK-DAG: %[[tid_x:.*]] = gpu.thread_id x
+    //CHECK-DAG: %[[c19:.*]] = arith.constant 19 : index
     %tid_x = gpu.thread_id x
     %c19 = arith.constant 19: index
     
-    //CHECK: %[[c16:.*]] = arith.constant 16 : index
-    //CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c19]], %[[c16]] : index
-    //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c19]], %[[c16]] : index
-    //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
-    //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
-    //CHECK: %[[c0:.*]] = arith.constant 0 : index
-    //CHECK: %[[c1024:.*]] = arith.constant 1024 : index
-    //CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c1024]] : index
-    //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
-    //CHECK: %[[c256:.*]] = arith.constant 256 : index
-    //CHECK: %[[mul1:.*]] = arith.muli %[[offsety_0]], %[[c256]] : index
-    //CHECK: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index
-    //CHECK: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c16]] : index
-    //CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index
-    //CHECK: %[[c1:.*]] = arith.constant 1 : index
-    //CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c1]] : index
-    //CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
-    //CHECK: %[[loaded:.*]] = llvm.load {{.*}} : !llvm.ptr<3> -> f16
+    //CHECK-DAG: %[[c16:.*]] = arith.constant 16 : index
+    //CHECK-DAG: %[[offsetx_0:.*]] = arith.divsi %[[c19]], %[[c16]] : index
+    //CHECK-DAG: %[[offsetx_1:.*]] = arith.remsi %[[c19]], %[[c16]] : index
+    //CHECK-DAG: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
+    //CHECK-DAG: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
+    //CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+    //CHECK-DAG: %[[c1024:.*]] = arith.constant 1024 : index
+    //CHECK-DAG: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c1024]] : index
+    //CHECK-DAG: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
+    //CHECK-DAG: %[[c256:.*]] = arith.constant 256 : index
+    //CHECK-DAG: %[[mul1:.*]] = arith.muli %[[offsety_0]], %[[c256]] : index
+    //CHECK-DAG: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index
+    //CHECK-DAG: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c16]] : index
+    //CHECK-DAG: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index
+    //CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
+    //CHECK-DAG: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c1]] : index
+    //CHECK-DAG: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
+    //CHECK-DAG: %[[loaded:.*]] = llvm.load {{.*}} : !llvm.ptr<3> -> f16
     %1 = xegpu.load_matrix %0[%c19, %tid_x]: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index -> f16
     
     //CHECK: llvm.store %[[loaded]], {{.*}} : f16, !llvm.ptr<3>
@@ -141,24 +145,24 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
   gpu.func @load_store_matrix_blocked_strided_return_vector(%arg0: memref<4096xi8, 3>) -> vector<8xf16> {
     %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>
 
-    //CHECK: %[[tid_x:.*]] = gpu.thread_id x
-    //CHECK: %[[c16:.*]] = arith.constant 16 : index
-    //CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c16]], %[[c16]] : index
-    //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c16]], %[[c16]] : index
-    //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
-    //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
-    //CHECK: %[[c0:.*]] = arith.constant 0 : index
-    //CHECK: %[[c256:.*]] = arith.constant 256 : index
-    //CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c256]] : index
-    //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
-    //CHECK: %[[c512:.*]] = arith.constant 512 : index
-    //CHECK: %[[mul1:.*]] = arith.muli %[[offsety_0]], %[[c512]] : index
-    //CHECK: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index
-    //CHECK: %[[c1:.*]] = arith.constant 1 : index
-    //CHECK: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c1]] : index
-    //CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index
-    //CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c16]] : index
-    //CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
+    //CHECK-DAG: %[[tid_x:.*]] = gpu.thread_id x
+    //CHECK-DAG: %[[c16:.*]] = arith.constant 16 : index
+    //CHECK-DAG: %[[offsetx_0:.*]] = arith.divsi %[[c16]], %[[c16]] : index
+    //CHECK-DAG: %[[offsetx_1:.*]] = arith.remsi %[[c16]], %[[c16]] : index
+    //CHECK-DAG: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
+    //CHECK-DAG: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
+    //CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+    //CHECK-DAG: %[[c256:.*]] = arith.constant 256 : index
+    //CHECK-DAG: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c256]] : index
+    //CHECK-DAG: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
+    //CHECK-DAG: %[[c512:.*]] = arith.constant 512 : index
+    //CHECK-DAG: %[[mul1:.*]] = arith.muli %[[offsety_0]], %[[c512]] : index
+    //CHECK-DAG: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index
+    //CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
+    //CHECK-DAG: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c1]] : index
+    //CHECK-DAG: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index
+    //CHECK-DAG: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c16]] : index
+    //CHECK-DAG: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
 
     //CHECK: %[[loaded:.*]] = llvm.load {{.*}}: !llvm.ptr<3> -> vector<8xf16>
      
@@ -178,8 +182,9 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
   //CHECK-LABEL: load_store_matrix_blocked_subgroupblockio
   gpu.func @load_store_matrix_blocked_subgroupblockio(%arg0: memref<4096xi8, 3>) -> vector<8xf16> {
 
-    //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %arg0 : memref<4096xi8, 3> -> index
-    //CHECK: %[[basePtrI32:.*]] = arith.index_castui %[[intptr]] : index to i32
+    //CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+    //CHECK-DAG: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[base_buffer:.*]] : memref<4096xi8, 3> -> index
+    //CHECK-DAG: %[[basePtrI32:.*]] = arith.index_castui %[[intptr]] : index to i32
      %0 = xegpu.create_mem_...
[truncated]

Contributor

@silee2 left a comment

LGTM.

Contributor

@charithaintc left a comment

Waiting for some clarifications before digging deeper.

ArrayRef<T> candidateMultiples = {});

/// Checks if the given MemRefType refers to shared memory.
bool isSharedMemRef(const MemRefType &memrefTy);
Contributor

nit: This is currently used in only one file (XeGPUToXeVM), so it can live there as a local helper. If more uses arise, we can move it to a common place.
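
If the helper does move into XeGPUToXeVM.cpp, its decision logic is small enough to keep file-local. The sketch below models that logic in plain C++ under stated assumptions: `MemorySpace` is a hypothetical stand-in for the MLIR memory-space attribute, and the numeric value 3 for shared memory is taken from the check this patch replaces (`getMemorySpaceAsInt() == 3`). It is not the MLIR API itself.

```cpp
#include <cassert>
#include <variant>

// Hypothetical stand-ins for the attribute kinds the helper inspects.
enum class AddrSpace : int { Global = 1, Shared = 3 }; // Global = 1 is assumed
struct GpuWorkgroupSpace {}; // models a GPU-dialect workgroup address space
struct GpuOtherSpace {};     // models any other GPU-dialect address space
using MemorySpace = std::variant<std::monostate, int, AddrSpace,
                                 GpuWorkgroupSpace, GpuOtherSpace>;

namespace {
// File-local helper: true when the memory space denotes shared memory (SLM).
bool isSharedMemorySpace(const MemorySpace &space) {
  if (std::holds_alternative<std::monostate>(space))
    return false; // no memory space attribute at all
  if (const int *raw = std::get_if<int>(&space))
    return *raw == static_cast<int>(AddrSpace::Shared);
  if (const AddrSpace *named = std::get_if<AddrSpace>(&space))
    return *named == AddrSpace::Shared;
  return std::holds_alternative<GpuWorkgroupSpace>(space);
}

// Pointer width used by the type conversion: i32 for SLM, i64 otherwise.
unsigned pointerBitWidth(const MemorySpace &space) {
  return isSharedMemorySpace(space) ? 32 : 64;
}
} // namespace
```

Keeping both functions in an anonymous namespace gives them internal linkage, which is the usual way to scope a helper to one translation unit until a second user appears.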

builder.getIndexAttr(intOffsets));
}

auto addr_casted =
Contributor

LLVM does not use snake_case for variable names. Rename to addrCasted.

SmallVector<int64_t> intStrides;
Value addr;
Value offset;
if (failed(memrefTy.getStridesAndOffset(intStrides, intOffsets))) {
Contributor

Why is this check needed? Why not use ExtractStridedMetadataOp for all cases? The getStridesAndOffset call above could return ShapedType::kDynamic for dynamic values.

addr = memref::ExtractAlignedPointerAsIndexOp::create(builder, loc,
input);
offset = arith::ConstantOp::create(builder, loc,
builder.getIndexAttr(intOffsets));
Contributor

This could be kDynamic, which is a special value, so this code is not correct.
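
One way to read the two comments above: a successful getStridesAndOffset call does not guarantee a static offset, so the static path should only be taken when the offset is not the dynamic sentinel. The sketch below models that guard in plain C++. `kDynamic` here is a local constant assumed to mirror ShapedType::kDynamic (the minimum int64_t value in current MLIR), and `foldableOffset` is a hypothetical helper, not part of the patch.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <optional>

// Assumption: same sentinel value as ShapedType::kDynamic in current MLIR.
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

// Returns the offset when it can be emitted as a compile-time constant.
// std::nullopt means the offset must be read at runtime instead.
inline std::optional<int64_t> foldableOffset(bool layoutIsStrided,
                                             int64_t staticOffset) {
  if (!layoutIsStrided)
    return std::nullopt; // strides/offset query failed
  if (staticOffset == kDynamic)
    return std::nullopt; // query succeeded, but the offset is dynamic
  return staticOffset;
}
```

A caller would emit a constant only when the result has a value and fall back to memref.extract_strided_metadata otherwise. Alternatively, as suggested above, the metadata op could be used unconditionally. Without the second check, the sentinel would be scaled by the element size and added to the base address as if it were a real offset.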
