
silee2 commented Dec 8, 2025

If the source strided memref is not fully static (at least one of its shape, strides, or offset is kDynamic), use the i64 source variant of xegpu.create_nd_tdesc.
With this change, xegpu.create_nd_tdesc ops created by lowering from the vector dialect can rely on getMixedOffsets, getMixedSizes, and getMixedStrides to get the relevant values.
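The selection rule can be modeled outside MLIR as a minimal sketch. This is not the patch's code: `kDynamic` stands in for `ShapedType::kDynamic` (assumed here to be the same `std::numeric_limits<int64_t>::min()` sentinel), and `isFullyStatic` is a hypothetical helper that mirrors the `isStatic` flag computed in `createNdDescriptor`.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <limits>
#include <vector>

// Stand-in for mlir::ShapedType::kDynamic (assumed to be the same sentinel).
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

// Hypothetical helper mirroring the isStatic flag in createNdDescriptor: a
// strided memref counts as fully static only if every size, every stride, and
// the offset are compile-time constants.
bool isFullyStatic(const std::vector<int64_t> &shape,
                   const std::vector<int64_t> &strides, int64_t offset) {
  assert(shape.size() == strides.size() && "expected one stride per dimension");
  auto isDynamic = [](int64_t v) { return v == kDynamic; };
  return offset != kDynamic &&
         std::none_of(shape.begin(), shape.end(), isDynamic) &&
         std::none_of(strides.begin(), strides.end(), isDynamic);
}
```

For example, `isFullyStatic({16, 32}, {32, 1}, kDynamic)` returns `false`. That layout (static 16x32 shape, dynamic offset) is what the old `hasStaticShape()`-only check treated as static, and the updated `load_2D_vector` CHECK lines now expect the i64 variant for it. In-tree, the stride loop could plausibly be shortened with `llvm::is_contained(strides, ShapedType::kDynamic)`, though that is a suggestion rather than what the patch does.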

llvmbot commented Dec 8, 2025

@llvm/pr-subscribers-mlir-gpu

Author: Sang Ik Lee (silee2)

Changes

If the source strided memref is not fully static (at least one of its shape, strides, or offset is kDynamic), use the i64 source variant.
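With the i64 source variant, the descriptor's source is only an integer address, so it carries no layout information of its own. A host-side C++ analogue of the `memref.extract_aligned_pointer_as_index` + `arith.index_cast` pair is sketched below, assuming a flat address space in which pointers fit in 64 bits; `baseAddressAsI64` and `elementByteAddress` are hypothetical names, not MLIR APIs.

```cpp
#include <cassert>
#include <cstdint>

// Host-side analogue (assumption) of
//   memref.extract_aligned_pointer_as_index + arith.index_cast (index -> i64):
// the aligned data pointer reinterpreted as a 64-bit integer.
int64_t baseAddressAsI64(const float *alignedPtr) {
  static_assert(sizeof(std::uintptr_t) <= sizeof(int64_t),
                "pointers are assumed to fit in i64");
  assert(alignedPtr != nullptr && "expected an allocated buffer");
  return static_cast<int64_t>(reinterpret_cast<std::uintptr_t>(alignedPtr));
}

// Byte address of the element `linearIndex` elements past the base. The raw
// integer says nothing about the buffer: the element type comes from the
// descriptor type, and offsets, shape, and strides are passed explicitly.
int64_t elementByteAddress(int64_t base, int64_t linearIndex) {
  return base + linearIndex * static_cast<int64_t>(sizeof(float));
}
```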


Patch is 20.16 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/171216.diff

5 Files Affected:

  • (modified) mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp (+35-5)
  • (modified) mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir (+18-8)
  • (modified) mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir (+18-8)
  • (modified) mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir (+22-10)
  • (modified) mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir (+23-12)
diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index 079e1e2a8ac67..b8606b261b781 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -102,18 +102,48 @@ static xegpu::CreateNdDescOp createNdDescriptor(PatternRewriter &rewriter,
                                                 xegpu::TensorDescType descType,
                                                 TypedValue<MemRefType> src) {
   MemRefType srcTy = src.getType();
+  assert(srcTy.isStrided() && "Expected strided memref type");
   auto [strides, offset] = srcTy.getStridesAndOffset();
+  bool isStatic = true;
+
+  // Memref is dynamic if any of its shape, offset or strides is dynamic.
+  if (!srcTy.hasStaticShape()) {
+    isStatic = false;
+  }
+
+  if (offset == ShapedType::kDynamic)
+    isStatic = false;
+
+  for (auto stride : strides) {
+    if (stride == ShapedType::kDynamic) {
+      isStatic = false;
+      break;
+    }
+  }
 
   xegpu::CreateNdDescOp ndDesc;
-  if (srcTy.hasStaticShape()) {
+  if (isStatic) {
     ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src);
   } else {
-    // In case of any dynamic shapes, source's shape and strides have to be
+    // For a memref that is not fully static, pass the i64 base address instead
+    // of the memref itself; source offsets, shape, and strides then have to be
     // explicitly provided.
     auto meta = memref::ExtractStridedMetadataOp::create(rewriter, loc, src);
-    ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src,
-                                           meta.getConstifiedMixedSizes(),
-                                           meta.getConstifiedMixedStrides());
+    auto baseAddrIndex = memref::ExtractAlignedPointerAsIndexOp::create(
+        rewriter, loc, meta.getBaseBuffer());
+    auto baseAddrI64 = arith::IndexCastOp::create(
+        rewriter, loc, rewriter.getI64Type(), baseAddrIndex.getResult());
+    // Strided metadata provides only a linearized offset, but create_nd_tdesc
+    // expects one offset per source dimension. Prepend zeros if rank > 1.
+    srcTy.getRank();
+    SmallVector<OpFoldResult> fullOffsets;
+    for (unsigned i = 0; i < srcTy.getRank() - 1; ++i) {
+      fullOffsets.push_back(rewriter.getI64IntegerAttr(0));
+    }
+    fullOffsets.push_back(meta.getConstifiedMixedOffset());
+    ndDesc = xegpu::CreateNdDescOp::create(
+        rewriter, loc, descType, baseAddrI64, fullOffsets,
+        meta.getConstifiedMixedSizes(), meta.getConstifiedMixedStrides());
   }
 
   return ndDesc;
diff --git a/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
index ae5141db16c09..867d1f20fb707 100644
--- a/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
@@ -10,9 +10,13 @@ func.func @load_1D_vector(%source: memref<8x16x32xf32>, %offset: index) -> vecto
 // CHECK-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
 // CHECK-SAME:  %[[OFFSET:.+]]: index
 // CHECK:       %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], %[[OFFSET]], 0]
-// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME:    %[[COLLAPSED]]
-// CHECK-SAME:    memref<32xf32, strided<[1], offset: ?>> -> !xegpu.tensor_desc<8xf32,
+// CHECK:       %[[BASE_BUFFER:.+]], %[[OFFSET1:.+]], %[[SIZES:.+]], %[[STRIDES:.+]] = memref.extract_strided_metadata %[[COLLAPSED]]
+// CHECK-SAME:    : memref<32xf32, strided<[1], offset: ?>> -> memref<f32>, index, index, index
+// CHECK:       %[[INTPTR:.+]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
+// CHECK-SAME:    : memref<f32> -> index
+// CHECK:       %[[I64PTR:.+]] = arith.index_cast %[[INTPTR]] : index to i64
+// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]][%[[OFFSET1]]], shape : [32],
+// CHECK-SAME:                   strides : [1] : i64  -> !xegpu.tensor_desc<8xf32,
 // CHECK-SAME:    boundary_check = false
 // CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]]]{{.*}}-> vector<8xf32>
 // CHECK:       return %[[VEC]]
@@ -30,9 +34,12 @@ func.func @load_2D_vector(%source: memref<8x16x32xf32>,
 // CHECK-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
 // CHECK-SAME:  %[[OFFSET:.+]]: index
 // CHECK:       %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], 0, 0]
-// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME:    %[[COLLAPSED]]
-// CHECK-SAME:    memref<16x32xf32, strided<[32, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32>
+// CHECK:       %[[BASE_BUFFER:.*]], %[[OFF1:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
+// CHECK:       %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
+// CHECK-SAME:    : memref<f32> -> index
+// CHECK:       %[[I64PTR:.*]] = arith.index_cast %[[INTPTR]] : index to i64
+// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]][0, %[[OFF1]]], shape : [16, 32],
+// CHECK-SAME:                   strides : [32, 1] : i64 -> !xegpu.tensor_desc<8x16xf32>
 // CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
 // CHECK:       return %[[VEC]]
 
@@ -49,8 +56,11 @@ func.func @load_dynamic_source(%source: memref<?x?x?xf32>,
 // CHECK-SAME:  %[[SRC:.+]]: memref<?x?x?xf32>,
 // CHECK-SAME:  %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
 // CHECK:       %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFF0]], 0, 0]
-// CHECK:       {{.*}} %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
-// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[COLLAPSED]]
+// CHECK:       %[[BASE_BUFFER:.*]], %[[OFFSET:.*]], %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
+// CHECK:       %[[INTPTR:.+]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]] : memref<f32> -> index
+// CHECK:       %[[I64PTR:.+]] = arith.index_cast %[[INTPTR]] : index to i64
+// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]][0, %[[OFFSET]]], shape : [%[[SIZES]]#0, %[[SIZES]]#1],
+// CHECK-SAME:                   strides : [%[[STRIDES]]#0, 1] : i64 -> !xegpu.tensor_desc<8x16xf32>
 // CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFF1]], %[[OFF2]]]{{.*}}-> vector<8x16xf32>
 // CHECK:       return %[[VEC]]
 
diff --git a/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
index 1a10d917623cc..09bd571951a6b 100644
--- a/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
@@ -12,9 +12,13 @@ func.func @store_1D_vector(%vec: vector<8xf32>,
 // CHECK-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
 // CHECK-SAME:  %[[OFFSET:.+]]: index
 // CHECK:       %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], %[[OFFSET]], 0]
-// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME:    %[[COLLAPSED]]
-// CHECK-SAME:    memref<32xf32, strided<[1], offset: ?>> -> !xegpu.tensor_desc<8xf32,
+// CHECK:       %[[BASE_BUFFER:.+]], %[[OFFSET1:.+]], %[[SIZES:.+]], %[[STRIDES:.+]] = memref.extract_strided_metadata %[[COLLAPSED]]
+// CHECK-SAME:    : memref<32xf32, strided<[1], offset: ?>> -> memref<f32>, index, index, index
+// CHECK:       %[[INTPTR:.+]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
+// CHECK-SAME:    : memref<f32> -> index
+// CHECK:       %[[I64PTR:.+]] = arith.index_cast %[[INTPTR]] : index to i64
+// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]][%[[OFFSET1]]], shape : [32],
+// CHECK-SAME:                   strides : [1] : i64  -> !xegpu.tensor_desc<8xf32,
 // CHECK-SAME:    boundary_check = false
 // CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]]] : vector<8xf32>
 
@@ -32,9 +36,12 @@ func.func @store_2D_vector(%vec: vector<8x16xf32>,
 // CHECK-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
 // CHECK-SAME:  %[[OFFSET:.+]]: index
 // CHECK:       %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], 0, 0]
-// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME:    %[[COLLAPSED]]
-// CHECK-SAME:    memref<16x32xf32, strided<[32, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32>
+// CHECK:       %[[BASE_BUFFER:.*]], %[[OFF1:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
+// CHECK:       %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
+// CHECK-SAME:    : memref<f32> -> index
+// CHECK:       %[[I64PTR:.*]] = arith.index_cast %[[INTPTR]] : index to i64
+// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]][0, %[[OFF1]]], shape : [16, 32],
+// CHECK-SAME:                   strides : [32, 1] : i64 -> !xegpu.tensor_desc<8x16xf32>
 // CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>
 
 // -----
@@ -51,8 +58,11 @@ func.func @store_dynamic_source(%vec: vector<8x16xf32>,
 // CHECK-SAME:  %[[SRC:.+]]: memref<?x?x?xf32>,
 // CHECK-SAME:  %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
 // CHECK:       %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFF0]], 0, 0]
-// CHECK:       {{.*}} %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
-// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[COLLAPSED]]
+// CHECK:       %[[BASE_BUFFER:.*]], %[[OFFSET:.*]], %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
+// CHECK:       %[[INTPTR:.+]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]] : memref<f32> -> index
+// CHECK:       %[[I64PTR:.+]] = arith.index_cast %[[INTPTR]] : index to i64
+// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]][0, %[[OFFSET]]], shape : [%[[SIZES]]#0, %[[SIZES]]#1],
+// CHECK-SAME:                   strides : [%[[STRIDES]]#0, 1] : i64 -> !xegpu.tensor
 // CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFF1]], %[[OFF2]]] : vector<8x16xf32>
 
 // -----
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
index 8bb272b1fe5fc..af330dced143e 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
@@ -49,9 +49,12 @@ gpu.func @load_2D_vector(%source: memref<8x16x32xf32>,
 // LOAD-ND-SAME:   %[[SRC:.+]]: memref<8x16x32xf32>,
 // LOAD-ND-SAME:   %[[OFFSET:.+]]: index
 // LOAD-ND:        %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], 0, 0]
-// LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc
-// LOAD-ND-SAME:     %[[COLLAPSED]]
-// LOAD-ND-SAME:     memref<16x32xf32, strided<[32, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32,
+// LOAD-ND:       %[[BASE_BUFFER:.*]], %[[OFF1:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
+// LOAD-ND:       %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
+// LOAD-ND-SAME:    : memref<f32> -> index
+// LOAD-ND:       %[[I64PTR:.*]] = arith.index_cast %[[INTPTR]] : index to i64
+// LOAD-ND:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]][0, %[[OFF1]]], shape : [16, 32],
+// LOAD-ND-SAME:                   strides : [32, 1] : i64 -> !xegpu.tensor_desc<8x16xf32,
 // LOAD-ND-SAME:     boundary_check = false
 // LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
 // LOAD-ND:        return %[[VEC]]
@@ -148,8 +151,12 @@ gpu.func @load_dynamic_source(%source: memref<?x?x?xf32>,
 // LOAD-ND-SAME:   %[[SRC:.+]]: memref<?x?x?xf32>,
 // LOAD-ND-SAME:   %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
 // LOAD-ND:        %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFF0]], 0, 0]
-// LOAD-ND:        {{.*}} %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
-// LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc %[[COLLAPSED]]
+// LOAD-ND:        %[[BASE_BUFFER:.*]], %[[OFFSET:.*]], %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
+// LOAD-ND:        %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]] : memref<f32> -> index
+// LOAD-ND:        %[[I64PTR:.*]] = arith.index_cast %[[INTPTR]] : index to i64
+// LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]][0, %[[OFFSET]]], shape : [%[[SIZES]]#0, %[[SIZES]]#1],
+// LOAD-ND-SAME:                    strides : [%[[STRIDES]]#0, 1] : i64 -> !xegpu.tensor_desc<8x16xf32,
+// LOAD-ND-SAME:                      #xegpu.block_tdesc_attr<boundary_check = false>>
 // LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFF1]], %[[OFF2]]]{{.*}}-> vector<8x16xf32>
 // LOAD-ND:        return %[[VEC]]
 
@@ -185,7 +192,11 @@ gpu.func @load_dynamic_source2(%source: memref<?x8x16xf32>,
 // LOAD-ND-SAME:   %[[SRC:.+]]: memref<?x8x16xf32>,
 // LOAD-ND-SAME:   %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
 // LOAD-ND:        %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFF0]], 0, 0]
-// LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf32, strided<[16, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>>
+// LOAD-ND:        %[[BASE_BUFFER:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
+// LOAD-ND:        %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
+// LOAD-ND:        %[[I64PTR:.*]] = arith.index_cast %[[INTPTR]] : index to i64
+// LOAD-ND:        %[[DESC:.*]] = xegpu.create_nd_tdesc %[[I64PTR]][0, %[[OFFSET]]], shape : [8, 16], strides : [16, 1] :
+// LOAD-ND-SAME:                    i64 -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>>
 // LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%{{.*}}, %{{.*}}] : !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>> -> vector<8x16xf32>
 // LOAD-ND:        return %[[VEC]] : vector<8x16xf32>
 
@@ -460,10 +471,11 @@ gpu.func @load_from_subview_2D(%source: memref<4096x4096xf16>, %off1: index, %of
 // LOAD-ND-SAME:   %[[SRC:.+]]: memref<4096x4096xf16>,
 // LOAD-ND-SAME:   %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
 // LOAD-ND:        %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
-// LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc
-// LOAD-ND-SAME:     %[[SUBVIEW]]
-// LOAD-ND-SAME:     memref<256x256xf16, strided<[4096, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf16,
-// LOAD-ND-SAME:     boundary_check = false
+// LOAD-ND:        %[[BASE_BUFFER:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[SUBVIEW]]
+// LOAD-ND:        %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
+// LOAD-ND:        %[[I64PTR:.*]] = arith.index_cast %[[INTPTR]] : index to i64
+// LOAD-ND:        %[[DESC:.*]] = xegpu.create_nd_tdesc %[[I64PTR]][0, %[[OFFSET]]], shape : [256, 256], strides : [4096, 1] :
+// LOAD-ND-SAME:                    i64 -> !xegpu.tensor_desc<8x16xf16, #xegpu.block_tdesc_attr<boundary_check = false>>
 // LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFF2]], %[[OFF2]]]{{.*}}-> vector<8x16xf16>
 // LOAD-ND:        return %[[VEC]]
 
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
index 43a1a7206e2cc..6185f8537d8e0 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
@@ -16,9 +16,13 @@ gpu.func @store_1D_vector(%vec: vector<8xf32>,
 // STORE-ND-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
 // STORE-ND-SAME:  %[[OFFSET:.+]]: index
 // STORE-ND:       %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], %[[OFFSET]], 0]
-// STORE-ND:       %[[DESC:.+]] = xegpu.create_nd_tdesc
-// STORE-ND-SAME:    %[[COLLAPSED]]
-// STORE-ND-SAME:    memref<32xf32, strided<[1], offset: ?>> -> !xegpu.tensor_desc<8xf32,
+// STORE-ND:       %[[BASE_BUFFER:.+]], %[[OFFSET1:.+]], %[[SIZES:.+]], %[[STRIDES:.+]] = memref.extract_strided_metadata %[[COLLAPSED]]
+// STORE-ND-SAME:    : memref<32xf32, strided<[1], offset: ?>> -> memref<f32>, index, index, index
+// STORE-ND:       %[[INTPTR:.+]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
+// STORE-ND-SAME:    : memref<f32> -> index
+// STORE-ND:       %[[I64PTR:.+]] = arith.index_cast %[[INTPTR]] : index to i64
+// STORE-ND:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]][%[[OFFSET1]]], shape : [32],
+// STORE-ND-SAME:                   strides : [1] : i64  -> !xegpu.tensor_desc<8xf32,
 // STORE-ND-SAME:    boundary_check = false
 // STORE-ND:       xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]]] : vector<8xf32>
 
@@ -51,9 +55,12 @@ gpu.func @store_2D_vector(%vec: vector<8x16xf32>,
 // STORE-ND-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
 // STORE-ND-SAME:  %[[OFFSET:.+]]: index
 // STORE-ND:       %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], 0, 0]
-// STORE-ND:       %[[DESC:.+]] = xegpu.create_nd_tdesc
-// STORE-ND-SAME:    %[[COLLAPSED]]
-// STORE-ND-SAME:    memref<16x32xf32, strided<[32, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32,
+// STORE-ND:       %[[BASE_BUFFER:.*]], %[[OFF1:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
+// STORE-ND:       %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
+// STORE-ND-SAME:    : memref<f32> -> index
+// STORE-ND:       %[[I64PTR:.*]] = arith.index_cast %[[INTPTR]] : index to i64
+// STORE-ND:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]][0, %[[OFF1]]], shape : [16, 32],
+// STORE-ND-SAME:                   strides : [32, 1] : i64 -> !xegpu.tensor_desc<8x16xf32,
 // STORE-ND-SAME:    boundary_check = false
 // STORE-ND:       xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>
 
@@ -87,8 +94,11 @@ gpu.func @store_dynamic_source(%vec: vector<8x16xf32>,
 // STORE-ND-SAME:  %[[SRC:.+]]: memref<?x?x?xf32>,
 // STORE-ND-SAME:  %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
 // STORE-ND:       %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFF0]], 0, 0]
-// STORE-ND:       {{.*}} %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
-// STORE-ND:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[COLLAPSED]]
+// STORE-ND:       %[[BASE_BUFFER:.*]], %[[OFFSET:.*]], %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
+// STORE-ND:       %[[INTPTR:.+]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]] : memref<f32> -> index
+// STORE-ND:       %[[I64PTR:.+]] = arith.index_cast %[[INTPTR]] : index to i64
+// STORE-ND:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]][0, %[[OFFSET]]], shape : [%[[SIZES]]#0, %[[SIZES]]#1],
+// STORE-ND-SAME:                   strides : [%[[STRIDES]]#0, 1] : i64 -> !xegpu.tensor
 // STORE-ND:       xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFF1]], %[[OFF2]]] : vector<8x16xf32>
 
 // STORE-SCATTER-LABEL: @store_dynamic_source(
@@ -295,10 +305,11 @@ gpu.func @store_to_subview(%vec: vector<8xf16>,
 // STORE-ND-SAME:   %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
 // STORE-ND:        %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
 // STORE-ND:        %[[COLLAPSED:.+]] = memref.subview %[[SUBVIEW]][%[[OFF2]], 0]
-// STORE-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc
-// STORE-ND-SAME:     %[[COLLAPSED]]
-// STORE-ND-SAME:     memref<256xf16, strided<[1], offset: ?>> -> !xegpu.tensor_desc<8xf16,
-// STORE-ND-SAME:     boundary_check = false
+// STORE-ND:        %[[BASE_BUFFER:.*]], %[[OFFSET:.*]], %[[SIZES:.*]], %[[STRIDES:.*]] = memref.extract_strided_metadata %[[COLLAPSED]]
+// STORE-ND:        %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
+// STORE-ND:        %[[I64PTR:.*]] = arith.index_cast %[[INTPTR]] : index to i64
+// STORE-ND:        %[[DESC:.*]] = xegpu.create_nd_tdesc %[[I64PTR]][%[[OFFSET]]], shape : [256], strides : [1] : i64 ->
+// STORE-ND-SAME:                    !xegpu.tensor_desc<8xf16, #xegpu.block_tdesc_...
[truncated]
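The offset handling in the dynamic branch can be checked with a small model. `memref.extract_strided_metadata` yields a single linearized offset, and the patch builds rank-sized offsets as `[0, ..., 0, offset]`. Assuming the descriptor addresses position `(i_0, ..., i_{n-1})` at `base + sum_k (off_k + i_k) * stride_k` elements, and that the innermost stride is 1 (as in every CHECK line above), this reproduces the memref's own `offset + sum_k i_k * stride_k` layout. `padOffsets` and `elementIndex` are hypothetical helpers, not MLIR or XeGPU APIs.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper mirroring the patch: rank - 1 leading zeros, then the
// single linearized offset produced by memref.extract_strided_metadata.
std::vector<int64_t> padOffsets(std::size_t rank, int64_t linearOffset) {
  assert(rank >= 1 && "expected a ranked source with at least one dimension");
  std::vector<int64_t> offsets(rank - 1, 0);
  offsets.push_back(linearOffset);
  return offsets;
}

// Element index, relative to the base buffer, of position `idx` under the
// assumed addressing model base + sum_k (offsets[k] + idx[k]) * strides[k].
int64_t elementIndex(const std::vector<int64_t> &offsets,
                     const std::vector<int64_t> &strides,
                     const std::vector<int64_t> &idx) {
  assert(offsets.size() == strides.size() && strides.size() == idx.size());
  int64_t linear = 0;
  for (std::size_t k = 0; k < idx.size(); ++k)
    linear += (offsets[k] + idx[k]) * strides[k];
  return linear;
}
```

For the `load_2D_vector` case, the subview of `memref<8x16x32xf32>` at `[%offset, 0, 0]` has strides `[32, 1]` and linearized offset `%offset * 512`. With `%offset = 3`, the padded offsets are `[0, 1536]`, and element `(5, 7)` lands at `1536 + 5*32 + 7 = 1703`, exactly where the memref layout puts it. Under this model, a non-unit innermost stride would scale the linearized offset, so the `[0, ..., 0, offset]` form relies on the innermost stride being 1.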

llvmbot commented Dec 8, 2025

@llvm/pr-subscribers-mlir

 // LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
 // LOAD-ND:        return %[[VEC]]
@@ -148,8 +151,12 @@ gpu.func @load_dynamic_source(%source: memref<?x?x?xf32>,
 // LOAD-ND-SAME:   %[[SRC:.+]]: memref<?x?x?xf32>,
 // LOAD-ND-SAME:   %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
 // LOAD-ND:        %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFF0]], 0, 0]
-// LOAD-ND:        {{.*}} %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
-// LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc %[[COLLAPSED]]
+// LOAD-ND:        %[[BASE_BUFFER:.*]], %[[OFFSET:.*]], %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
+// LOAD-ND:        %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]] : memref<f32> -> index
+// LOAD-ND:        %[[I64PTR:.*]] = arith.index_cast %[[INTPTR]] : index to i64
+// LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]][0, %[[OFFSET]]], shape : [%[[SIZES]]#0, %[[SIZES]]#1],
+// LOAD-ND-SAME:                    strides : [%[[STRIDES]]#0, 1] : i64 -> !xegpu.tensor_desc<8x16xf32,
+// LOAD-ND-SAME:                      #xegpu.block_tdesc_attr<boundary_check = false>>
 // LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFF1]], %[[OFF2]]]{{.*}}-> vector<8x16xf32>
 // LOAD-ND:        return %[[VEC]]
 
@@ -185,7 +192,11 @@ gpu.func @load_dynamic_source2(%source: memref<?x8x16xf32>,
 // LOAD-ND-SAME:   %[[SRC:.+]]: memref<?x8x16xf32>,
 // LOAD-ND-SAME:   %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
 // LOAD-ND:        %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFF0]], 0, 0]
-// LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf32, strided<[16, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>>
+// LOAD-ND:        %[[BASE_BUFFER:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
+// LOAD-ND:        %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
+// LOAD-ND:        %[[I64PTR:.*]] = arith.index_cast %[[INTPTR]] : index to i64
+// LOAD-ND:        %[[DESC:.*]] = xegpu.create_nd_tdesc %[[I64PTR]][0, %[[OFFSET]]], shape : [8, 16], strides : [16, 1] :
+// LOAD-ND-SAME:                    i64 -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>>
 // LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%{{.*}}, %{{.*}}] : !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>> -> vector<8x16xf32>
 // LOAD-ND:        return %[[VEC]] : vector<8x16xf32>
 
@@ -460,10 +471,11 @@ gpu.func @load_from_subview_2D(%source: memref<4096x4096xf16>, %off1: index, %of
 // LOAD-ND-SAME:   %[[SRC:.+]]: memref<4096x4096xf16>,
 // LOAD-ND-SAME:   %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
 // LOAD-ND:        %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
-// LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc
-// LOAD-ND-SAME:     %[[SUBVIEW]]
-// LOAD-ND-SAME:     memref<256x256xf16, strided<[4096, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf16,
-// LOAD-ND-SAME:     boundary_check = false
+// LOAD-ND:        %[[BASE_BUFFER:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[SUBVIEW]]
+// LOAD-ND:        %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
+// LOAD-ND:        %[[I64PTR:.*]] = arith.index_cast %[[INTPTR]] : index to i64
+// LOAD-ND:        %[[DESC:.*]] = xegpu.create_nd_tdesc %[[I64PTR]][0, %[[OFFSET]]], shape : [256, 256], strides : [4096, 1] :
+// LOAD-ND-SAME:                    i64 -> !xegpu.tensor_desc<8x16xf16, #xegpu.block_tdesc_attr<boundary_check = false>>
 // LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFF2]], %[[OFF2]]]{{.*}}-> vector<8x16xf16>
 // LOAD-ND:        return %[[VEC]]
 
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
index 43a1a7206e2cc..6185f8537d8e0 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
@@ -16,9 +16,13 @@ gpu.func @store_1D_vector(%vec: vector<8xf32>,
 // STORE-ND-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
 // STORE-ND-SAME:  %[[OFFSET:.+]]: index
 // STORE-ND:       %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], %[[OFFSET]], 0]
-// STORE-ND:       %[[DESC:.+]] = xegpu.create_nd_tdesc
-// STORE-ND-SAME:    %[[COLLAPSED]]
-// STORE-ND-SAME:    memref<32xf32, strided<[1], offset: ?>> -> !xegpu.tensor_desc<8xf32,
+// STORE-ND:       %[[BASE_BUFFER:.+]], %[[OFFSET1:.+]], %[[SIZES:.+]], %[[STRIDES:.+]] = memref.extract_strided_metadata %[[COLLAPSED]]
+// STORE-ND-SAME:    : memref<32xf32, strided<[1], offset: ?>> -> memref<f32>, index, index, index
+// STORE-ND:       %[[INTPTR:.+]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
+// STORE-ND-SAME:    : memref<f32> -> index
+// STORE-ND:       %[[I64PTR:.+]] = arith.index_cast %[[INTPTR]] : index to i64
+// STORE-ND:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]][%[[OFFSET1]]], shape : [32],
+// STORE-ND-SAME:                   strides : [1] : i64  -> !xegpu.tensor_desc<8xf32,
 // STORE-ND-SAME:    boundary_check = false
 // STORE-ND:       xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]]] : vector<8xf32>
 
@@ -51,9 +55,12 @@ gpu.func @store_2D_vector(%vec: vector<8x16xf32>,
 // STORE-ND-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
 // STORE-ND-SAME:  %[[OFFSET:.+]]: index
 // STORE-ND:       %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], 0, 0]
-// STORE-ND:       %[[DESC:.+]] = xegpu.create_nd_tdesc
-// STORE-ND-SAME:    %[[COLLAPSED]]
-// STORE-ND-SAME:    memref<16x32xf32, strided<[32, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32,
+// STORE-ND:       %[[BASE_BUFFER:.*]], %[[OFF1:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
+// STORE-ND:       %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
+// STORE-ND-SAME:    : memref<f32> -> index
+// STORE-ND:       %[[I64PTR:.*]] = arith.index_cast %[[INTPTR]] : index to i64
+// STORE-ND:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]][0, %[[OFF1]]], shape : [16, 32],
+// STORE-ND-SAME:                   strides : [32, 1] : i64 -> !xegpu.tensor_desc<8x16xf32,
 // STORE-ND-SAME:    boundary_check = false
 // STORE-ND:       xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>
 
@@ -87,8 +94,11 @@ gpu.func @store_dynamic_source(%vec: vector<8x16xf32>,
 // STORE-ND-SAME:  %[[SRC:.+]]: memref<?x?x?xf32>,
 // STORE-ND-SAME:  %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
 // STORE-ND:       %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFF0]], 0, 0]
-// STORE-ND:       {{.*}} %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
-// STORE-ND:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[COLLAPSED]]
+// STORE-ND:       %[[BASE_BUFFER:.*]], %[[OFFSET:.*]], %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
+// STORE-ND:       %[[INTPTR:.+]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]] : memref<f32> -> index
+// STORE-ND:       %[[I64PTR:.+]] = arith.index_cast %[[INTPTR]] : index to i64
+// STORE-ND:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]][0, %[[OFFSET]]], shape : [%[[SIZES]]#0, %[[SIZES]]#1],
+// STORE-ND-SAME:                   strides : [%[[STRIDES]]#0, 1] : i64 -> !xegpu.tensor
 // STORE-ND:       xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFF1]], %[[OFF2]]] : vector<8x16xf32>
 
 // STORE-SCATTER-LABEL: @store_dynamic_source(
@@ -295,10 +305,11 @@ gpu.func @store_to_subview(%vec: vector<8xf16>,
 // STORE-ND-SAME:   %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
 // STORE-ND:        %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
 // STORE-ND:        %[[COLLAPSED:.+]] = memref.subview %[[SUBVIEW]][%[[OFF2]], 0]
-// STORE-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc
-// STORE-ND-SAME:     %[[COLLAPSED]]
-// STORE-ND-SAME:     memref<256xf16, strided<[1], offset: ?>> -> !xegpu.tensor_desc<8xf16,
-// STORE-ND-SAME:     boundary_check = false
+// STORE-ND:        %[[BASE_BUFFER:.*]], %[[OFFSET:.*]], %[[SIZES:.*]], %[[STRIDES:.*]] = memref.extract_strided_metadata %[[COLLAPSED]]
+// STORE-ND:        %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
+// STORE-ND:        %[[I64PTR:.*]] = arith.index_cast %[[INTPTR]] : index to i64
+// STORE-ND:        %[[DESC:.*]] = xegpu.create_nd_tdesc %[[I64PTR]][%[[OFFSET]]], shape : [256], strides : [1] : i64 ->
+// STORE-ND-SAME:                    !xegpu.tensor_desc<8xf16, #xegpu.block_tdesc_...
[truncated]

Contributor

@charithaintc left a comment

LGTM.

auto baseAddrI64 = arith::IndexCastOp::create(
    rewriter, loc, rewriter.getI64Type(), baseAddrIndex.getResult());
// Strided metadata only provides a 1-D offset, but the create_nd_tdesc op
// expects offsets to match the rank of the source memref. Add leading zeros
// if rank > 1.
Contributor

The 1d offset needs to be added to the baseAddrI64 here.

Contributor Author

Updated PR to generate adjusted base addr: base addr + offset * element_size_in_bytes

bool isStatic = true;

// Memref is dynamic if any of its shape, offset or strides is dynamic.
if (!srcTy.hasStaticShape()) {
Contributor

nit: braces can be skipped

Contributor Author

Removed braces.

@silee2 silee2 requested a review from fabianmcg as a code owner December 10, 2025 22:12
@silee2 silee2 removed the request for review from fabianmcg December 10, 2025 22:19