Conversation

silee2 (Contributor) commented Dec 2, 2025

Base memory pitch should be derived from the base stride, not the base width.
Remove the offset fields from the tensor descriptor payload and add a pitch field.
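
In terms of the lowering, the change boils down to the following (a condensed sketch drawn from the patch below, using the names in XeGPUToXeVM.cpp; the first statement lives in the descriptor-creation pattern and the rest in the load/store/prefetch pattern):

// The pitch in elements is taken from the outer-dimension stride operand.
Value basePitch = createOffset(mixedStrides, 0);
// It is converted to bytes and passed as the row pitch of the XeVM 2D block
// ops, where the width in bytes (surfaceW) was previously reused.
Value basePitchByte =
    arith::MulIOp::create(rewriter, loc, basePitch, elemByteSize);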

llvmbot (Member) commented Dec 2, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-gpu

Author: Sang Ik Lee (silee2)

Changes

Base memory pitch should be derived from base stride, not base width.


Patch is 22.96 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/170384.diff

4 Files Affected:

  • (modified) mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp (+24-25)
  • (modified) mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir (+16-21)
  • (modified) mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir (+7-53)
  • (modified) mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir (+5-16)
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 7f1ec17ce0ae8..9c99a24bea8cd 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -50,11 +50,10 @@ static constexpr int32_t executionSize{16};
 
 // Offsets to individual fields of the 8xi32 layout nd tensor descriptor.
 enum class NdTdescOffset : uint32_t {
-  BasePtr = 0,       // Base pointer (i64)
-  BaseShapeW = 2,    // Base shape width (i32)
-  BaseShapeH = 3,    // Base shape height (i32)
-  TensorOffsetW = 4, // Tensor offset W (i32)
-  TensorOffsetH = 5  // Tensor offset H (i32)
+  BasePtr = 0,    // Base pointer (i64)
+  BaseShapeW = 2, // Base shape width (i32)
+  BaseShapeH = 3, // Base shape height (i32)
+  BasePitch = 4,  // Base pitch (i32)
 };
 
 static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
@@ -179,11 +178,10 @@ class CreateNdDescToXeVMPattern
     Value baseAddr;
     Value baseShapeW;
     Value baseShapeH;
-    Value offsetW;
-    Value offsetH;
 
     // Source can be a memref or a pointer (ui64, ui32, i64 or i32).
     SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
+    SmallVector<OpFoldResult> mixedStrides = op.getMixedStrides();
     // Descriptor shape is expected to be 2D.
     int64_t rank = mixedSizes.size();
     auto sourceTy = source.getType();
@@ -216,12 +214,11 @@ class CreateNdDescToXeVMPattern
       val = getValueOrCreateCastToIndexLike(rewriter, loc, payloadElemTy, val);
       return val;
     };
-    // Offsets are not supported (0 is used).
-    offsetW = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
-    offsetH = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
     // Get shape values from op fold results.
     baseShapeW = createOffset(mixedSizes, 1);
     baseShapeH = createOffset(mixedSizes, 0);
+    // Get pitch value from op fold results.
+    Value basePitch = createOffset(mixedStrides, 0);
     // Populate payload.
     Value payLoadAsI64 =
         vector::BitCastOp::create(rewriter, loc, payloadI64Ty, payload);
@@ -235,12 +232,9 @@ class CreateNdDescToXeVMPattern
     payload =
         vector::InsertOp::create(rewriter, loc, baseShapeH, payload,
                                  static_cast<int>(NdTdescOffset::BaseShapeH));
-    payload = vector::InsertOp::create(
-        rewriter, loc, offsetW, payload,
-        static_cast<int>(NdTdescOffset::TensorOffsetW));
-    payload = vector::InsertOp::create(
-        rewriter, loc, offsetH, payload,
-        static_cast<int>(NdTdescOffset::TensorOffsetH));
+    payload =
+        vector::InsertOp::create(rewriter, loc, basePitch, payload,
+                                 static_cast<int>(NdTdescOffset::BasePitch));
     rewriter.replaceOp(op, payload);
     return success();
   }
@@ -289,6 +283,8 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
           rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW));
       Value baseShapeH = vector::ExtractOp::create(
           rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH));
+      Value basePitch = vector::ExtractOp::create(
+          rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BasePitch));
       // Offsets are provided by the op.
       // convert them to i32.
       Value offsetW =
@@ -303,8 +299,11 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
       Value basePtrLLVM =
           LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
       // Compute width in bytes.
-      Value surfaceW =
+      Value baseWidthByte =
           arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);
+      // Compute pitch in bytes.
+      Value basePitchByte =
+          arith::MulIOp::create(rewriter, loc, basePitch, elemByteSize);
 
       // Get tile width from the tensor descriptor type.
       auto tileW = tdescTy.getDimSize(tileRank - 1);
@@ -331,8 +330,8 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
         auto storeCacheControl =
             translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
         xevm::BlockStore2dOp::create(
-            rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW,
-            offsetH, elemBitSize, tileW, tileH, src,
+            rewriter, loc, basePtrLLVM, baseWidthByte, baseShapeH,
+            basePitchByte, offsetW, offsetH, elemBitSize, tileW, tileH, src,
             xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
         rewriter.eraseOp(op);
       } else {
@@ -340,9 +339,9 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
             translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
         if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) {
           xevm::BlockPrefetch2dOp::create(
-              rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW,
-              offsetW, offsetH, elemBitSize, tileW, tileH, vblocks,
-              xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
+              rewriter, loc, basePtrLLVM, baseWidthByte, baseShapeH,
+              basePitchByte, offsetW, offsetH, elemBitSize, tileW, tileH,
+              vblocks, xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
           rewriter.eraseOp(op);
         } else {
           VectorType dstVecTy = cast<VectorType>(op.getValue().getType());
@@ -355,9 +354,9 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
                              : rewriter.getIntegerType(elemBitSize));
 
           Value resultFlatVec = xevm::BlockLoad2dOp::create(
-              rewriter, loc, loadedTy, basePtrLLVM, surfaceW, baseShapeH,
-              surfaceW, offsetW, offsetH, elemBitSize, tileW, tileH, vblocks,
-              transpose, vnni,
+              rewriter, loc, loadedTy, basePtrLLVM, baseWidthByte, baseShapeH,
+              basePitchByte, offsetW, offsetH, elemBitSize, tileW, tileH,
+              vblocks, transpose, vnni,
               xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
           resultFlatVec = vector::BitCastOp::create(
               rewriter, loc,
diff --git a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
index 8b87b791c9fd3..9a1e2cb3c7de0 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
@@ -8,21 +8,19 @@ gpu.module @create_nd_tdesc {
   gpu.func @create_nd_tdesc(%src: memref<16x32xf32, 1>, %ptr: ui64, %shape1: index, %shape2: index,
   %stride1: index, %stride2: index, %offset1: index, %offset2: index, %dyn: memref<?x?xf16>) kernel {
         // CHECK: %[[INTPTR_5:.*]] = memref.extract_aligned_pointer_as_index %[[DYN]] : memref<?x?xf16> -> index
-        // CHECK: %[[VAR23:.*]] = arith.index_castui %[[INTPTR_5]] : index to i64
+        // CHECK: %[[DYN_ADDR:.*]] = arith.index_castui %[[INTPTR_5]] : index to i64
         // CHECK: %[[VAR0:.*]] = index.castu %[[ARG1]] : ui64 to index
         // CHECK: %[[BASE_ADDR:.*]] = arith.index_castui %[[VAR0]] : index to i64
         // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32>
-        // CHECK: %[[OFFSET_W:.*]] = arith.constant 0 : i32
-        // CHECK: %[[OFFSET_H:.*]] = arith.constant 0 : i32
         // CHECK: %[[SHAPE_W:.*]] = arith.index_cast %[[ARG3]] : index to i32
         // CHECK: %[[SHAPE_H:.*]] = arith.index_cast %[[ARG2]] : index to i32
+        // CHECK: %[[PITCH:.*]] = arith.index_cast %[[ARG4]] : index to i32
         // CHECK: %[[VAR6:.*]] = vector.bitcast %[[CST]] : vector<8xi32> to vector<4xi64>
         // CHECK: %[[VAR7:.*]] = vector.insert %[[BASE_ADDR]], %[[VAR6]] [0] : i64 into vector<4xi64>
         // CHECK: %[[VAR8:.*]] = vector.bitcast %[[VAR7]] : vector<4xi64> to vector<8xi32>
         // CHECK: %[[VAR9:.*]] = vector.insert %[[SHAPE_W]], %[[VAR8]] [2] : i32 into vector<8xi32>
         // CHECK: %[[VAR10:.*]] = vector.insert %[[SHAPE_H]], %[[VAR9]] [3] : i32 into vector<8xi32>
-        // CHECK: %[[VAR11:.*]] = vector.insert %[[OFFSET_W]], %[[VAR10]] [4] : i32 into vector<8xi32>
-        // CHECK: %[[VAR12:.*]] = vector.insert %[[OFFSET_H]], %[[VAR11]] [5] : i32 into vector<8xi32>
+        // CHECK: %[[VAR11:.*]] = vector.insert %[[PITCH]], %[[VAR10]] [4] : i32 into vector<8xi32>
         %ptr_tdesc = xegpu.create_nd_tdesc %ptr, shape:[%shape1, %shape2], strides:[%stride1, %stride2]
             : ui64 -> !xegpu.tensor_desc<8x16xf32>
 
@@ -32,19 +30,18 @@ gpu.module @create_nd_tdesc {
         // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<16x32xf32> -> index
         // CHECK: %[[BASE_ADDR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64
         // CHECK: %[[CST_1:.*]] = arith.constant dense<0> : vector<8xi32>
-        // CHECK: %[[OFFSET_W2:.*]] = arith.constant 0 : i32
-        // CHECK: %[[OFFSET_H2:.*]] = arith.constant 0 : i32
         // CHECK: %[[C32_I64:.*]] = arith.constant 32 : i64
         // CHECK: %[[SHAPE_W2:.*]] = arith.trunci %[[C32_I64]] : i64 to i32
         // CHECK: %[[C16_I64:.*]] = arith.constant 16 : i64
         // CHECK: %[[SHAPE_H2:.*]] = arith.trunci %[[C16_I64]] : i64 to i32
+        // CHECK: %[[C32_I64_2:.*]] = arith.constant 32 : i64
+        // CHECK: %[[PITCH2:.*]] = arith.trunci %[[C32_I64_2]] : i64 to i32
         // CHECK: %[[VAR14:.*]] = vector.bitcast %[[CST_1]] : vector<8xi32> to vector<4xi64>
         // CHECK: %[[VAR15:.*]] = vector.insert %[[BASE_ADDR2]], %[[VAR14]] [0] : i64 into vector<4xi64>
         // CHECK: %[[VAR16:.*]] = vector.bitcast %[[VAR15]] : vector<4xi64> to vector<8xi32>
         // CHECK: %[[VAR17:.*]] = vector.insert %[[SHAPE_W2]], %[[VAR16]] [2] : i32 into vector<8xi32>
         // CHECK: %[[VAR18:.*]] = vector.insert %[[SHAPE_H2]], %[[VAR17]] [3] : i32 into vector<8xi32>
-        // CHECK: %[[VAR19:.*]] = vector.insert %[[OFFSET_W2]], %[[VAR18]] [4] : i32 into vector<8xi32>
-        // CHECK: %[[PAYLOAD:.*]] = vector.insert %[[OFFSET_H2]], %[[VAR19]] [5] : i32 into vector<8xi32>
+        // CHECK: %[[VAR19:.*]] = vector.insert %[[PITCH2]], %[[VAR18]] [4] : i32 into vector<8xi32>
         %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
 
         // CHECK: %[[C1:.*]] = arith.constant 1 : index
@@ -53,18 +50,16 @@ gpu.module @create_nd_tdesc {
         %size_x = arith.constant 64 : index
         // CHECK: %[[C16:.*]] = arith.constant 16 : index
         %BLOCK_DMODEL = arith.constant 16 : index
-        // CHECK: %[[CST_4:.*]] = arith.constant dense<0> : vector<8xi32>
-        // CHECK: %[[C0_I32_6:.*]] = arith.constant 0 : i32
-        // CHECK: %[[C0_I32_7:.*]] = arith.constant 0 : i32
-        // CHECK: %[[VAR21:.*]] = arith.index_cast %[[C16]] : index to i32
-        // CHECK: %[[VAR22:.*]] = arith.index_cast %[[C64]] : index to i32
-        // CHECK: %[[VAR24:.*]] = vector.bitcast %[[CST_4]] : vector<8xi32> to vector<4xi64>
-        // CHECK: %[[VAR25:.*]] = vector.insert %[[VAR23]], %[[VAR24]] [0] : i64 into vector<4xi64>
-        // CHECK: %[[VAR26:.*]] = vector.bitcast %[[VAR25]] : vector<4xi64> to vector<8xi32>
-        // CHECK: %[[VAR27:.*]] = vector.insert %[[VAR21]], %[[VAR26]] [2] : i32 into vector<8xi32>
-        // CHECK: %[[VAR28:.*]] = vector.insert %[[VAR22]], %[[VAR27]] [3] : i32 into vector<8xi32>
-        // CHECK: %[[VAR29:.*]] = vector.insert %[[C0_I32_6]], %[[VAR28]] [4] : i32 into vector<8xi32>
-        // CHECK: %[[VAR30:.*]] = vector.insert %[[C0_I32_7]], %[[VAR29]] [5] : i32 into vector<8xi32>
+        // CHECK: %[[CST_3:.*]] = arith.constant dense<0> : vector<8xi32>
+        // CHECK: %[[SHAPE_W3:.*]] = arith.index_cast %[[C16]] : index to i32
+        // CHECK: %[[SHAPE_H3:.*]] = arith.index_cast %[[C64]] : index to i32
+        // CHECK: %[[PITCH3:.*]] = arith.index_cast %[[C16]] : index to i32
+        // CHECK: %[[VAR25:.*]] = vector.bitcast %[[CST_3]] : vector<8xi32> to vector<4xi64>
+        // CHECK: %[[VAR26:.*]] = vector.insert %[[DYN_ADDR]], %[[VAR25]] [0] : i64 into vector<4xi64>
+        // CHECK: %[[VAR27:.*]] = vector.bitcast %[[VAR26]] : vector<4xi64> to vector<8xi32>
+        // CHECK: %[[VAR28:.*]] = vector.insert %[[SHAPE_W3]], %[[VAR27]] [2] : i32 into vector<8xi32>
+        // CHECK: %[[VAR29:.*]] = vector.insert %[[SHAPE_H3]], %[[VAR28]] [3] : i32 into vector<8xi32>
+        // CHECK: %[[VAR30:.*]] = vector.insert %[[PITCH3]], %[[VAR29]] [4] : i32 into vector<8xi32>
         %dyn_tdesc  = xegpu.create_nd_tdesc %dyn, shape: [%size_x, %BLOCK_DMODEL], strides: [%BLOCK_DMODEL, %c1] : memref<?x?xf16> -> !xegpu.tensor_desc<16x16xf16>
         gpu.return
     }
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir
index afeae8be24b72..4c73c9c238b6e 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir
@@ -1,78 +1,32 @@
-// RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s
+// RUN: mlir-opt -convert-xegpu-to-xevm -canonicalize %s | FileCheck %s
 
 gpu.module @load_store_check {
     // CHECK-LABEL: gpu.func @load_store(
-    // CHECK-SAME: %[[ARG0:.*]]: memref<8x16xf32, 1>, %[[ARG1:.*]]: memref<8x16xf32, 1>) kernel {
     gpu.func @load_store(%src: memref<8x16xf32, 1>, %dst: memref<8x16xf32, 1>) kernel {
+        // CHECK: %[[W_P_BYTES:.*]] = arith.constant 64 : i32
+        // CHECK: %[[ZERO:.*]] = arith.constant 0 : i32
+        // CHECK: %[[H:.*]] = arith.constant 8 : i32
         %srcce = memref.memory_space_cast %src : memref<8x16xf32, 1> to memref<8x16xf32>
         %dstte = memref.memory_space_cast %dst : memref<8x16xf32, 1> to memref<8x16xf32>
 
-        // CHECK: %[[MEMSPACECAST:.*]] = memref.memory_space_cast %[[ARG0]] : memref<8x16xf32, 1> to memref<8x16xf32>
-        // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<8x16xf32> -> index
-        // CHECK: %[[LD_PTR_AS_I64:.*]] = arith.index_castui %[[INTPTR]] : index to i64
-        // CHECK: %[[MEMSPACECAST_0:.*]] = memref.memory_space_cast %[[ARG1]] : memref<8x16xf32, 1> to memref<8x16xf32>
-        // CHECK: %[[INTPTR_1:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST_0]] : memref<8x16xf32> -> index
-        // CHECK: %[[ST_PTR_AS_I64:.*]] = arith.index_castui %[[INTPTR_1]] : index to i64
-        // CHECK: %[[LD_CREATE_DESC_I64:.*]] = vector.bitcast {{.*}} : vector<8xi32> to vector<4xi64>
-        // CHECK: %[[LD_DESC_0:.*]] = vector.insert %[[LD_PTR_AS_I64]], %[[LD_CREATE_DESC_I64]] [0] : i64 into vector<4xi64>
-        // CHECK: %[[LD_DESC_1:.*]] = vector.bitcast %[[LD_DESC_0]] : vector<4xi64> to vector<8xi32>
-        // CHECK: %[[LD_DESC_2:.*]] = vector.insert {{.*}}, %[[LD_DESC_1]] [2] : i32 into vector<8xi32>
-        // CHECK: %[[LD_DESC_3:.*]] = vector.insert {{.*}}, %[[LD_DESC_2]] [3] : i32 into vector<8xi32>
-        // CHECK: %[[LD_DESC_4:.*]] = vector.insert {{.*}}, %[[LD_DESC_3]] [4] : i32 into vector<8xi32>
-        // CHECK: %[[LD_DESC:.*]] = vector.insert {{.*}}, %[[LD_DESC_4]] [5] : i32 into vector<8xi32>
         %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
 
-
-        //CHECK: %[[LD_SIZEOF_F32:.*]] = arith.constant 4 : i32
-        //CHECK: %[[LD_DESC_I64:.*]] = vector.bitcast %[[LD_DESC]] : vector<8xi32> to vector<4xi64>
-        //CHECK: %[[LD_INTPTR:.*]] = vector.extract %[[LD_DESC_I64]][0] : i64 from vector<4xi64>
-        //CHECK: %[[LD_BASE_W:.*]] = vector.extract %[[LD_DESC]][2] : i32 from vector<8xi32>
-        //CHECK: %[[LD_BASE_H:.*]] = vector.extract %[[LD_DESC]][3] : i32 from vector<8xi32>
-        //CHECK: %[[LD_TILE_W64:.*]] = arith.constant 0 : i64
-        //CHECK: %[[LD_TILE_W:.*]] = arith.trunci %[[LD_TILE_W64]] : i64 to i32
-        //CHECK: %[[LD_TILE_H64:.*]] = arith.constant 0 : i64
-        //CHECK: %[[LD_TILE_H:.*]] = arith.trunci %[[LD_TILE_H64]] : i64 to i32
-        //CHECK: %[[LD_LLVMPTR:.*]] = llvm.inttoptr %[[LD_INTPTR]] : i64 to !llvm.ptr<1>
-        //CHECK: %[[LD_BASE_ROW_IN_BYTES:.*]] = arith.muli %[[LD_BASE_W]], %[[LD_SIZEOF_F32]] : i32
-        //CHECK: %[[LD_LOADED_I32:.*]] = xevm.blockload2d %[[LD_LLVMPTR]], %[[LD_BASE_ROW_IN_BYTES]],
-        //CHECK-SAME: %[[LD_BASE_H]], %[[LD_BASE_ROW_IN_BYTES]], %[[LD_TILE_W]], %[[LD_TILE_H]]
+        //CHECK: %[[LD_LOADED_I32:.*]] = xevm.blockload2d %{{.*}}, %[[W_P_BYTES]], %[[H]], %[[W_P_BYTES]], %[[ZERO]], %[[ZERO]]
         //CHECK-SAME: <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>, elem_size_in_bits = 32 : i32,
         //CHECK-SAME:   pack_register = false, tile_height = 8 : i32, tile_width = 16 : i32, transpose = false,
         //CHECK-SAME:   v_blocks = 1 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32>
         %loaded = xegpu.load_nd %src_tdesc[0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
             : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
-        //CHECK: %[[LD_LOADED_F32:.*]] = vector.bitcast %[[LD_LOADED_I32]] : vector<8xi32> to vector<8xf32>
 
         %tid_x = gpu.thread_id x
         %tid_x_i32 = arith.index_cast %tid_x : index to i32
         %tid_x_f32 = arith.sitofp %tid_x_i32 : i32 to f32
-        //CHECK: %[[LOADED_F32_MODIFIED:.*]] = vector.insert %{{.*}}, %[[LD_LOADED_F32]] [0] : f32 into vector<8xf32>
         %loaded_modified = vector.insert %tid_x_f32, %loaded[0] : f32 into vector<8xf32>
 
-        // CHECK: %[[CREATE_DESC_I64:.*]] = vector.bitcast {{.*}} : vector<8xi32> to vector<4xi64>
-        // CHECK: %[[DESC_0:.*]] = vector.insert %[[ST_PTR_AS_I64]], %[[CREATE_DESC_I64]] [0] : i64 into vector<4xi64>
-        // CHECK: %[[DESC_1:.*]] = vector.bitcast %[[DESC_0]] : vector<4xi64> to vector<8xi32>
-        // CHECK: %[[DESC_2:.*]] = vector.insert {{.*}}, %[[DESC_1]] [2] : i32 into vector<8xi32>
-        // CHECK: %[[DESC_3:.*]] = vector.insert {{.*}}, %[[DESC_2]] [3] : i32 into vector<8xi32>
-        // CHECK: %[[DESC_4:.*]] = vector.insert {{.*}}, %[[DESC_3]] [4] : i32 into vector<8xi32>
-        // CHECK: %[[DESC:.*]] = vector.insert {{.*}}, %[[DESC_4]] [5] : i32 into vector<8xi32>
         %dst_tdesc = xegpu.create_nd_tdesc %dstte : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>>
 
-        //CHECK: %[[SIZEOF_F32:.*]] = arith.constant 4 : i32
-        //CHECK: %[[DESC_I64:.*]] = vector.bitcast %[[DESC]] : vector<8xi32> to vector<4xi64>
-        //CHECK: %[[INTPTR:.*]] = vector.extract %[[DESC_I64]][0] : i64 from vector<4xi64>
-        //CHECK: %[[BASE_W:.*]] = vector.extract %[[DESC]][2] : i32 from vector<8xi32>
-        //CHECK: %[[BASE_H:.*]] = vector.extract %[[DESC]][3] : i32 from vector<8xi32>
-        //CHECK: %[[TILE_W64:.*]] = arith.constant 0 : i64
-        //CHECK: %[[TILE_W:.*]] = arith.trunci %[[TILE_W64]] : i64 to i32
-        //CHECK: %[[TILE_H64:.*]] = arith.constant 0 : i64
-        //CHECK: %[[TILE_H:.*]] = arith.trunci %[[TILE_H64]] : i64 to i32
-        //CHECK: %[[LLVMPTR:.*]] = llvm.inttoptr %[[INTPTR]] : i64 to !llvm.ptr<1>
-        //CHECK: %[[BASE_ROW_IN_BYTES:.*]] = arith.muli %[[BASE_W]], %[[SIZEOF_F32]] : i32
-        //CHECK: %[[FLAT_VALUE_I32:.*]] = vector.bitcast %[[LOADED_F32_MODIFIED]] : vector<8xf32> to vector<8xi32>
-        //CHECK: xevm.blockstore2d %[[LLVMPTR]], %[[BASE_ROW_IN_BYTES]], %[[BASE_H]], %[[BASE_ROW_IN_BYTES]],
-        //CHECK-SAME: %[[TILE_W]], %[[TILE_H]], %[[FLAT_VALUE_I32]]
-        //CHECK-SAME: <{cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>, elem_size_in_bits = 32 : i32,
+        //CHECK: xevm.blockstore2d %{{.*}}, %[[W_P_BYTES]], %[[H]], %[[W_P_BYTES]], %[[ZERO]], %[[ZERO]], %{{.*}} <{
+        //CHECK-SAME:   cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>, elem_size_in_bits = 32 : i32,
         //CHECK-SAME:   tile_height = 8 : i32, tile_width = 16 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>)
         xegpu.store_nd %loaded_modified, %dst_tdesc[0, 0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>
             : vector<8xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>>
diff --git a/mlir/test/Conve...
[truncated]


The following inline review thread is attached to this hunk in XeGPUToXeVM.cpp:

// Source can be a memref or a pointer (ui64, ui32, i64 or i32).
SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
SmallVector<OpFoldResult> mixedStrides = op.getMixedStrides();
Jianhui-Li (Contributor) commented Dec 3, 2025

We should differentiate between the source being a memref and it being a pointer. For a pointer source, the user is expected to provide both shapes and strides, so the above code works fine. But for a memref source, the user may not know the strides, so the code should extract them from the memref. For a dynamic-shape memref, this triggers ExtractStridedMetadataOp again (after the one in the type conversion that gets the base address and offset), but I expect the duplicate to be removed at the LLVM level.
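
One possible shape of that memref-side handling (a hypothetical sketch only, not what this PR implements; it assumes memref::ExtractStridedMetadataOp is usable here and reuses the source, sourceTy, payloadElemTy, and basePitch values already present in the pattern):

// Hypothetical: when the source is a memref, recover the row stride from the
// memref itself instead of relying on the op's stride operands.
if (isa<MemRefType>(sourceTy)) {
  auto meta =
      memref::ExtractStridedMetadataOp::create(rewriter, loc, source);
  // The outer-dimension stride (in elements) becomes the base pitch.
  basePitch = getValueOrCreateCastToIndexLike(rewriter, loc, payloadElemTy,
                                              meta.getStrides()[0]);
}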

silee2 (Contributor, Author) replied

Even if a ranked dynamic memref triggers ExtractStridedMetadataOp, lowering will clean it up and allow direct access to the relevant fields of the lowered and decomposed memref. See:

$ cat strided.mlir
module {
  func.func @test(%arg0: memref<?x?xf32>) -> (index) {
    %base, %offset, %sizes:2, %strides:2 =
      memref.extract_strided_metadata %arg0 : memref<?x?xf32>
        -> memref<f32>, index, index, index, index, index
    return %strides#0 : index
  }
}
$ ./build/bin/mlir-opt --convert-to-llvm -canonicalize strided.mlir
module {
  llvm.func @test(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64) -> i64 {
    llvm.return %arg5 : i64
  }
}

You can see that the stride is forwarded directly from the kernel argument, which is lowered and unpacked from the memref.

Another contributor commented

If I am not mistaken, for dynamic memrefs this still gets the strides from CreateNd's own parameters, not from the memref, right?

Is that done in a separate PR?

@Jianhui-Li Seems like this PR also still gets the strides from CreateNd? In that case, is it fine to move ahead with #170218? We can then fix the whole thing (take the strides using ExtractStridedMetadataOp) and remove shape and strides from CreateNd in a new PR?

silee2 (Contributor, Author) replied

Yes. This PR does not address the issue of dynamic memrefs; they will be handled in a separate PR.

Jianhui-Li (Contributor) left a comment

LGTM

Jianhui-Li (Contributor) left a comment

LGTM.
