
Conversation

@dchigarev (Contributor) commented Oct 6, 2025

Changes the VectorToXeGPU pass to generate xegpu.load_nd/store_nd ops using the new syntax, where offsets are specified at the load/store op level.

// from this
%desc = xegpu.create_nd_tdesc %src[%off1, %off2]: memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
%res = xegpu.load_nd %desc : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>

// to this
%desc = xegpu.create_nd_tdesc %src: memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
%res = xegpu.load_nd %desc[%off1, %off2] : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>

In order to support cases with dimension reduction at the create_nd_tdesc level (e.g. memref<8x8x16xf16> -> tensor_desc<8x16xf16>), I had to tweak the load_nd/store_nd/prefetch_nd verification to allow the number of offsets to mismatch the tensor descriptor rank.

Why do we need to change this?
// reduce dim and apply all 3 offsets at load_nd
%desc = xegpu.create_nd_tdesc %source : memref<8x16x32xf32> -> !xegpu.tensor_desc<16x32xf32>
// error: xegpu.load_nd len(offsets) != desc.rank
%res = xegpu.load_nd %desc[%off, %off, %off] : !xegpu.tensor_desc<16x32xf32> -> vector<16x32xf32>

The new verification logic checks that the number of offsets matches either the source rank (e.g. the memref's rank) or the tensor descriptor rank. Since TensorDescType doesn't carry source shape information (only the shape of the tile), we can only run this check when the create_nd_tdesc operation is reachable via tdesc.getDefiningOp(). If it isn't, we skip the verification and postpone it until the xegpu-to-xevm pass, where create_nd_tdesc must be available.
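For illustration, a rough sketch of what the relaxed verifier is meant to accept (the %off values and result shapes below are illustrative, not taken from the patch):

// offsets match the source (memref) rank of 3 -- accepted
%desc = xegpu.create_nd_tdesc %source : memref<8x16x32xf32> -> !xegpu.tensor_desc<16x32xf32>
%res = xegpu.load_nd %desc[%off0, %off1, %off2] : !xegpu.tensor_desc<16x32xf32> -> vector<16x32xf32>

// offsets match the tensor descriptor rank of 2 -- also accepted
%res2 = xegpu.load_nd %desc[%off1, %off2] : !xegpu.tensor_desc<16x32xf32> -> vector<16x32xf32>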

@dchigarev force-pushed the dchigarev/vector-to-xe-new-offs branch from 2e8e6d7 to 8581183 on October 16, 2025 at 10:47
Signed-off-by: dchigarev <[email protected]>
@dchigarev marked this pull request as ready for review on October 16, 2025 at 12:35
@llvmbot (Member) commented Oct 16, 2025

@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir

Author: Dmitry Chigarev (dchigarev)

Changes

Changes the VectorToXeGPU pass to generate xegpu.load_nd/store_nd ops using the new syntax, where offsets are specified at the load/store op level.

// from this
%desc = xegpu.create_nd_tdesc %src[%off1, %off2]: memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
%res = xegpu.load_nd %desc : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>

// to this
%desc = xegpu.create_nd_tdesc %src: memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
%res = xegpu.load_nd %desc[%off1, %off2] : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>

In order to support cases with dimension reduction at the create_nd_tdesc level (e.g. memref<8x8x16xf16> -> tensor_desc<8x16xf16>), I had to tweak the load_nd/store_nd/prefetch_nd verification to allow the number of offsets to mismatch the tensor descriptor rank.

Why do we need to change this?

// reduce dim and apply all 3 offsets at load_nd
%desc = xegpu.create_nd_tdesc %source : memref<8x16x32xf32> -> !xegpu.tensor_desc<16x32xf32>
// error: xegpu.load_nd len(offsets) != desc.rank
%res = xegpu.load_nd %desc[%off, %off, %off] : !xegpu.tensor_desc<16x32xf32> -> vector<16x32xf32>


The new verification logic checks that the number of offsets matches either the source rank (e.g. the memref's rank) or the tensor descriptor rank. Since TensorDescType doesn't carry source shape information (only the shape of the tile), we can only run this check when the create_nd_tdesc operation is reachable via tdesc.getDefiningOp(). If it isn't, we skip the verification and postpone it until the xegpu-to-xevm pass, where create_nd_tdesc must be available.


Patch is 31.93 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/162095.diff

7 Files Affected:

  • (modified) mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp (+46-51)
  • (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp (+50-34)
  • (modified) mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir (+8-8)
  • (modified) mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir (+8-8)
  • (modified) mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir (+14-14)
  • (modified) mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir (+10-10)
  • (modified) mlir/test/Dialect/XeGPU/invalid.mlir (+3-3)
diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index e2c7d803e5a5e..7f11d427191e5 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -97,18 +97,17 @@ static LogicalResult transferPreconditions(PatternRewriter &rewriter,
   return success();
 }
 
-static xegpu::CreateNdDescOp
-createNdDescriptor(PatternRewriter &rewriter, Location loc,
-                   xegpu::TensorDescType descType, TypedValue<MemRefType> src,
-                   Operation::operand_range offsets) {
+static xegpu::CreateNdDescOp createNdDescriptor(PatternRewriter &rewriter,
+                                                Location loc,
+                                                xegpu::TensorDescType descType,
+                                                TypedValue<MemRefType> src) {
   MemRefType srcTy = src.getType();
   auto [strides, offset] = srcTy.getStridesAndOffset();
 
   xegpu::CreateNdDescOp ndDesc;
-  if (srcTy.hasStaticShape()) {
-    ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src,
-                                           getAsOpFoldResult(offsets));
-  } else {
+  if (srcTy.hasStaticShape())
+    ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src);
+  else {
     // In case of any dynamic shapes, source's shape and strides have to be
     // explicitly provided.
     SmallVector<Value> sourceDims;
@@ -116,38 +115,31 @@ createNdDescriptor(PatternRewriter &rewriter, Location loc,
     for (unsigned i = 0; i < srcRank; ++i)
       sourceDims.push_back(memref::DimOp::create(rewriter, loc, src, i));
 
-    SmallVector<int64_t> constOffsets;
-    SmallVector<Value> dynOffsets;
-    for (Value offset : offsets) {
-      std::optional<int64_t> staticVal = getConstantIntValue(offset);
-      if (!staticVal)
-        dynOffsets.push_back(offset);
-      constOffsets.push_back(staticVal.value_or(ShapedType::kDynamic));
-    }
-
-    SmallVector<Value> dynShapes;
+    SmallVector<OpFoldResult> mixedShapes;
     for (auto [idx, shape] : llvm::enumerate(srcTy.getShape())) {
       if (shape == ShapedType::kDynamic)
-        dynShapes.push_back(sourceDims[idx]);
+        mixedShapes.push_back(sourceDims[idx]);
+      else
+        mixedShapes.push_back(rewriter.getI64IntegerAttr(shape));
     }
 
     // Compute strides in reverse order.
-    SmallVector<Value> dynStrides;
+    SmallVector<OpFoldResult> mixedStrides;
     Value accStride = arith::ConstantIndexOp::create(rewriter, loc, 1);
     // Last stride is guaranteed to be static and unit.
+    mixedStrides.push_back(rewriter.getI64IntegerAttr(1));
     for (int i = static_cast<int>(strides.size()) - 2; i >= 0; --i) {
       accStride =
           arith::MulIOp::create(rewriter, loc, accStride, sourceDims[i + 1]);
       if (strides[i] == ShapedType::kDynamic)
-        dynStrides.push_back(accStride);
+        mixedStrides.push_back(accStride);
+      else
+        mixedStrides.push_back(rewriter.getI64IntegerAttr(strides[i]));
     }
-    std::reverse(dynStrides.begin(), dynStrides.end());
+    std::reverse(mixedStrides.begin(), mixedStrides.end());
 
-    ndDesc = xegpu::CreateNdDescOp::create(
-        rewriter, loc, descType, src, dynOffsets, dynShapes, dynStrides,
-        DenseI64ArrayAttr::get(rewriter.getContext(), constOffsets),
-        DenseI64ArrayAttr::get(rewriter.getContext(), srcTy.getShape()),
-        DenseI64ArrayAttr::get(rewriter.getContext(), strides));
+    ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src,
+                                           mixedShapes, mixedStrides);
   }
 
   return ndDesc;
@@ -523,21 +515,21 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
         descShape, elementType, /*array_length=*/1,
         /*boundary_check=*/isOutOfBounds, xegpu::MemorySpace::Global);
 
-    xegpu::CreateNdDescOp ndDesc =
-        createNdDescriptor(rewriter, loc, descType,
-                           dyn_cast<TypedValue<MemRefType>>(readOp.getBase()),
-                           readOp.getIndices());
-
     DenseI64ArrayAttr transposeAttr =
         !isTransposeLoad ? nullptr
                          : DenseI64ArrayAttr::get(rewriter.getContext(),
                                                   ArrayRef<int64_t>{1, 0});
     // By default, no specific caching policy is assigned.
     xegpu::CachePolicyAttr hint = nullptr;
-    auto loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc,
-                                          /*packed=*/nullptr, transposeAttr,
-                                          /*l1_hint=*/hint,
-                                          /*l2_hint=*/hint, /*l3_hint=*/hint);
+    xegpu::CreateNdDescOp ndDesc =
+        createNdDescriptor(rewriter, loc, descType,
+                           dyn_cast<TypedValue<MemRefType>>(readOp.getBase()));
+
+    auto loadOp = xegpu::LoadNdOp::create(
+        rewriter, loc, vecTy, ndDesc, getAsOpFoldResult(readOp.getIndices()),
+        /*packed=*/nullptr, transposeAttr,
+        /*l1_hint=*/hint,
+        /*l2_hint=*/hint, /*l3_hint=*/hint);
     rewriter.replaceOp(readOp, loadOp);
 
     return success();
@@ -579,15 +571,15 @@ struct TransferWriteLowering
         vecTy.getShape(), vecTy.getElementType(),
         /*array_length=*/1, /*boundary_check=*/writeOp.hasOutOfBoundsDim(),
         xegpu::MemorySpace::Global);
+    // By default, no specific caching policy is assigned.
+    xegpu::CachePolicyAttr hint = nullptr;
     xegpu::CreateNdDescOp ndDesc =
         createNdDescriptor(rewriter, loc, descType,
-                           dyn_cast<TypedValue<MemRefType>>(writeOp.getBase()),
-                           writeOp.getIndices());
+                           dyn_cast<TypedValue<MemRefType>>(writeOp.getBase()));
 
-    // By default, no specific caching policy is assigned.
-    xegpu::CachePolicyAttr hint = nullptr;
     auto storeOp =
         xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(), ndDesc,
+                                 getAsOpFoldResult(writeOp.getIndices()),
                                  /*l1_hint=*/hint,
                                  /*l2_hint=*/hint, /*l3_hint=*/hint);
     rewriter.replaceOp(writeOp, storeOp);
@@ -674,17 +666,18 @@ struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
 
     // Boundary check is available only for block instructions.
     bool boundaryCheck = vecTy.getRank() > 1;
+    // By default, no specific caching policy is assigned.
+    xegpu::CachePolicyAttr hint = nullptr;
 
     auto descType = xegpu::TensorDescType::get(
         vecTy.getShape(), vecTy.getElementType(), /*array_length=*/1,
         boundaryCheck, xegpu::MemorySpace::Global);
-    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
-        rewriter, loc, descType, loadOp.getBase(), loadOp.getIndices());
 
-    // By default, no specific caching policy is assigned.
-    xegpu::CachePolicyAttr hint = nullptr;
+    xegpu::CreateNdDescOp ndDesc =
+        createNdDescriptor(rewriter, loc, descType, loadOp.getBase());
     auto loadNdOp = xegpu::LoadNdOp::create(
-        rewriter, loc, vecTy, ndDesc, /*packed=*/nullptr, /*transpose=*/nullptr,
+        rewriter, loc, vecTy, ndDesc, getAsOpFoldResult(loadOp.getIndices()),
+        /*packed=*/nullptr, /*transpose=*/nullptr,
         /*l1_hint=*/hint,
         /*l2_hint=*/hint, /*l3_hint=*/hint);
     rewriter.replaceOp(loadOp, loadNdOp);
@@ -711,15 +704,17 @@ struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
     auto descType = xegpu::TensorDescType::get(
         vecTy.getShape(), vecTy.getElementType(),
         /*array_length=*/1, boundaryCheck, xegpu::MemorySpace::Global);
-    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
-        rewriter, loc, descType, storeOp.getBase(), storeOp.getIndices());
 
     // By default, no specific caching policy is assigned.
     xegpu::CachePolicyAttr hint = nullptr;
-    auto storeNdOp =
-        xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc,
-                                 /*l1_hint=*/hint,
-                                 /*l2_hint=*/hint, /*l3_hint=*/hint);
+    xegpu::CreateNdDescOp ndDesc =
+        createNdDescriptor(rewriter, loc, descType, storeOp.getBase());
+
+    auto storeNdOp = xegpu::StoreNdOp::create(
+        rewriter, loc, vector, ndDesc, getAsOpFoldResult(storeOp.getIndices()),
+        /*l1_hint=*/hint,
+        /*l2_hint=*/hint, /*l3_hint=*/hint);
+
     rewriter.replaceOp(storeOp, storeNdOp);
 
     return success();
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index e0a8ac40648e0..e09a084ac7ad2 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -121,6 +121,39 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy,
   return success();
 }
 
+// Verify that number of offsets matches either the source rank or the tdesc
+// rank.
+static LogicalResult
+isValidNdOffset(TypedValue<TensorDescType> tDesc,
+                std::optional<llvm::ArrayRef<int64_t>> constOffsets,
+                int64_t offsetSize,
+                function_ref<InFlightDiagnostic()> emitError) {
+  if (auto createTDescOp = tDesc.getDefiningOp<CreateNdDescOp>()) {
+    // If CreateNdDescOp is available, we can further
+    // check the offsets rank against the source rank.
+    auto staticSource = createTDescOp.getConstShapeAttr();
+    int64_t sourceRank;
+    if (!staticSource || staticSource.empty()) {
+      auto sourceTy = dyn_cast<MemRefType>(createTDescOp.getSourceType());
+      sourceRank = sourceTy.getRank();
+    } else
+      sourceRank = staticSource.size();
+
+    int64_t constOffsetSize = constOffsets ? constOffsets->size() : 0;
+    auto tDescRank = tDesc.getType().getRank();
+    bool sourceRankMismatch =
+        ((offsetSize != 0) && (offsetSize != sourceRank)) ||
+        ((constOffsetSize != 0) && (constOffsetSize != sourceRank));
+    bool tdescRankMismatch =
+        ((offsetSize != 0) && (offsetSize != tDescRank)) ||
+        ((constOffsetSize != 0) && (constOffsetSize != tDescRank));
+    if (sourceRankMismatch && tdescRankMismatch)
+      return emitError() << "Offsets rank must match either the source or the "
+                            "TensorDesc rank.";
+  }
+  return success();
+}
+
 static LogicalResult
 isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy,
                                  VectorType valueTy, int64_t chunkSize,
@@ -215,8 +248,10 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
     auto [memrefStrides, _] = memrefTy.getStridesAndOffset();
 
     // if shape and strides are from Memref, we don't need attributes for them
-    // to keep the IR print clean.
-    if (staticShape == memrefShape && staticStrides == memrefStrides) {
+    // to keep the IR print clean (only do so for full-static case, otherwise
+    // printer would fail trying to print empty array-attr).
+    if (staticShape == memrefShape && staticStrides == memrefStrides &&
+        dynamicShape.empty() && dynamicStrides.empty()) {
       staticShapeAttr = DenseI64ArrayAttr();
       staticStridesAttr = DenseI64ArrayAttr();
     }
@@ -277,8 +312,10 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
     auto [memrefStrides, _] = memrefTy.getStridesAndOffset();
 
     // if shape and strides are from Memref, we don't need attributes for them
-    // to keep the IR print clean.
-    if (staticShape == memrefShape && staticStrides == memrefStrides) {
+    // to keep the IR print clean (only do so for full-static case, otherwise
+    // printer would fail trying to print empty array-attr).
+    if (staticShape == memrefShape && staticStrides == memrefStrides &&
+        dynamicShape.empty() && dynamicStrides.empty()) {
       staticShapeAttr = DenseI64ArrayAttr();
       staticStridesAttr = DenseI64ArrayAttr();
     }
@@ -428,16 +465,9 @@ LogicalResult PrefetchNdOp::verify() {
   if (!isReadHintOrNone(getL3HintAttr()))
     return emitOpError("invalid l3_hint: ") << getL3HintAttr();
 
-  int64_t tDescRank = tdescTy.getRank();
-  int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
-  int64_t constOffsetSize =
-      getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
-  if (((offsetSize != 0) && (offsetSize != tDescRank)) ||
-      ((constOffsetSize != 0) && (constOffsetSize != tDescRank)))
-    return emitOpError(
-        "Mismatched ranks between offsets and tensor descriptor");
-
-  return success();
+  auto tDesc = getTensorDesc();
+  return isValidNdOffset(tDesc, getConstOffsets(), getMixedOffsets().size(),
+                         [&]() { return emitOpError(); });
 }
 
 //===----------------------------------------------------------------------===//
@@ -553,16 +583,9 @@ LogicalResult LoadNdOp::verify() {
                          << " is not consistent with tensor descriptor "
                          << tdescTy;
 
-  int64_t tDescRank = tdescTy.getRank();
-  int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
-  int64_t constOffsetSize =
-      getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
-  if (((offsetSize != 0) && (offsetSize != tDescRank)) ||
-      ((constOffsetSize != 0) && (constOffsetSize != tDescRank)))
-    return emitOpError(
-        "Mismatched ranks between offsets and tensor descriptor");
-
-  return success();
+  auto tDesc = getTensorDesc();
+  return isValidNdOffset(tDesc, getConstOffsets(), getMixedOffsets().size(),
+                         [&]() { return emitOpError(); });
 }
 
 //===----------------------------------------------------------------------===//
@@ -647,16 +670,9 @@ LogicalResult StoreNdOp::verify() {
                          << " is not consistent with tensor descriptor "
                          << dstTy;
 
-  int64_t tDescRank = dstTy.getRank();
-  int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
-  int64_t constOffsetSize =
-      getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
-  if (((offsetSize != 0) && (offsetSize != tDescRank)) ||
-      ((constOffsetSize != 0) && (constOffsetSize != tDescRank)))
-    return emitOpError(
-        "Mismatched ranks between offsets and tensor descriptor");
-
-  return success();
+  auto tDesc = getTensorDesc();
+  return isValidNdOffset(tDesc, getConstOffsets(), getMixedOffsets().size(),
+                         [&]() { return emitOpError(); });
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
index 9908205f07c92..b5fb2c4aa3e27 100644
--- a/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
@@ -10,10 +10,10 @@ func.func @load_1D_vector(%source: memref<8x16x32xf32>, %offset: index) -> vecto
 // CHECK-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
 // CHECK-SAME:  %[[OFFSET:.+]]: index
 // CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME:    %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME:    %[[SRC]]
 // CHECK-SAME:    memref<8x16x32xf32> -> !xegpu.tensor_desc<8xf32,
 // CHECK-SAME:    boundary_check = false
-// CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8xf32>
+// CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8xf32>
 // CHECK:       return %[[VEC]]
 
 // -----
@@ -29,9 +29,9 @@ func.func @load_2D_vector(%source: memref<8x16x32xf32>,
 // CHECK-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
 // CHECK-SAME:  %[[OFFSET:.+]]: index
 // CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME:    %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME:    %[[SRC]]
 // CHECK-SAME:    memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
+// CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
 // CHECK:       return %[[VEC]]
 
 // -----
@@ -53,10 +53,10 @@ func.func @load_dynamic_source(%source: memref<?x?x?xf32>,
 // CHECK-DAG:   %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
 // CHECK-DAG:   %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
 // CHECK:       %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
-// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]]
 // CHECK-SAME:  , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
 // CHECK-SAME:    memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
+// CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
 // CHECK:       return %[[VEC]]
 
 // -----
@@ -72,9 +72,9 @@ func.func @load_out_of_bounds(%source: memref<7x15xf32>,
 // CHECK-SAME:  %[[SRC:.+]]: memref<7x15xf32>,
 // CHECK-SAME:  %[[OFFSET:.+]]: index
 // CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME:    %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME:    %[[SRC]]
 // CHECK-SAME:    memref<7x15xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
+// CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
 // CHECK:       return %[[VEC]]
 
 // -----
diff --git a/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
index 2c498dcc2a071..57e754f7d7c00 100644
--- a/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
@@ -12,10 +12,10 @@ func.func @store_1D_vector(%vec: vector<8xf32>,
 // CHECK-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
 // CHECK-SAME:  %[[OFFSET:.+]]: index
 // CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME:    %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME:    %[[SRC]]
 // CHECK-SAME:    memref<8x16x32xf32> -> !xegpu.tensor_desc<8xf32,
 // CHECK-SAME:    boundary_check = false
-// CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8xf32>
+// CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]] : vector<8xf32>
 
 // -----
 
@@ -31,9 +31,9 @@ func.func @store_2D_vector(%vec: vector<8x16xf32>,
 // CHECK-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
 // CHECK-SAME:  %[[OFFSET:.+]]: index
 // CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME:    %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME:    %[[SRC]]
 // CHECK-SAME:    memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
+// CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>
 
 // -----
 
@@ -55,10 +55,10 @@ func.func @store_dynamic_source(%vec: vector<8x16xf32>,
 // CHECK-DAG:   %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
 // CHECK-DAG:   %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
 // CHECK:       %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
-// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]]
 // CHECK-SAME:  , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
 // CHECK-SAME:    memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
+// CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>
 
 // -----
 
@@ -74,9 +74,9 @@ func.func @store_out_of_bounds(%vec: vector<8x16xf32>,
 // CHECK-SAME:  %[[SRC:.+]]: memref<7x64xf32>,
 // CHECK-SAME:  %[[OFFSET:.+]]: index
 // CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME:    %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME:    %[[SRC]]
 // CHECK-SAME:    memref<7x64xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
+// CHECK:       xegpu.store_nd %[[VEC]], %[...
[truncated]

@dchigarev (Contributor, Author):

cc @Jianhui-Li @adam-smnk

@Jianhui-Li (Contributor) left a comment:

LGTM with some comments

return success();
}

// Verify that number of offsets matches either the source rank or the tdesc
Contributor:

I think we should simply check that offsets can't have a lower rank than the tdesc rank. It should be fine if it is larger than the tdesc rank.

I am not sure that it is common practice for op validation to validate itself using information from the producer op. I would rather check this in transformation or lowering passes. Any opinion from @adam-smnk?

Contributor:

Op verifier should be contained to the op itself without accessing any external data.

Contributor:

I think we should simply check that offsets can't have a lower rank than the tdesc rank. It should be fine if it is larger than the tdesc rank.

Sounds reasonable. Just please clearly define how offsets are interpreted.
For example, memref is 4D, descriptor is 2D and there are 3 offsets. I guess they'd apply to the innermost dimensions.

Contributor Author:

Simplified the check to not rely on CreateNdDescOp.

Also added docs for the new offset syntax.

@adam-smnk (Contributor) commented:

Looking at the older update_nd_offset semantics, the update is also restricted to the descriptor's rank.
I'd expect the new offset syntax to essentially fold update_nd_offset into load_nd while leaving descriptor creation as is.
If you want to fully avoid offsets at create_nd_tdesc, then an extra memref.subview could be added to create the initial tile pointer.

Is there any clear benefit in relaxing load offset semantics?
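
For context, a rough structural sketch of that folding, assuming the existing update_nd_offset assembly form (the exact interpretation of the load_nd offsets is what is being discussed here):

// before: offsets advanced via update_nd_offset on the descriptor
%desc = xegpu.create_nd_tdesc %src[%off1, %off2] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
%next = xegpu.update_nd_offset %desc, [%step1, %step2] : !xegpu.tensor_desc<8x16xf16>
%res  = xegpu.load_nd %next : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>

// after: offsets passed directly to load_nd, descriptor creation left as is
%desc = xegpu.create_nd_tdesc %src[%off1, %off2] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
%res  = xegpu.load_nd %desc[%step1, %step2] : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>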

@adam-smnk (Contributor) left a comment:

(can be a separate PR)

Could you also update ops' docs with the new offset semantics?

@github-actions bot commented Oct 21, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

Signed-off-by: dchigarev <[email protected]>
Comment on lines -138 to -142
// expected-error@+1 {{Mismatched ranks between offsets and tensor descriptor}}
%2 = xegpu.load_nd %1[0, 0] : !xegpu.tensor_desc<16xf16> -> vector<16xf16>
return
}

@Jianhui-Li (Contributor) commented:

Looking at the older update_nd_offset semantics, the update is also restricted to the descriptor's rank. I'd expect the new offset syntax to essentially fold update_nd_offset into load_nd while leaving descriptor creation as is. If you want to fully avoid offsets at create_nd_tdesc, then an extra memref.subview could be added to create initial tile pointer.

Is there any clear benefit in relaxing load offset semantics?

Actually this is a great question. The original motivation is that users want to express a 2D block in an nD tensor with the higher dimensions (n >= 2) flattened into the 2nd dimension. After that, the user only works with the flattened 2D tensor and moves the 2D block within it.

So I think the right solution is what you proposed: as we move the offsets to load_nd, the XeGPU user must explicitly insert a memref.subview to flatten the ND tensor to 2D, and then create a 2D tensor descriptor.

Allowing 3D+ offsets creates an issue during lowering to XeVM, even if we allow it at IR creation time. The tensor descriptor in HW only tracks the stride of the innermost dimension, so it only supports 2D block loads. Allowing 3D+ offsets means we would need to either compute the flattened 2D offsets inside the K-loop, or expand the tensor descriptor to track more than one stride. This increases complexity and may cause a negative performance impact. Instead, I think the right balance between ease of use and performance is to ask the user to do a subview to flatten the tensor upfront.

See the discussion here #164701
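
A rough sketch of the subview-based flow described above (the offsets and the rank-reduced result type are illustrative, not taken from the patch; depending on what create_nd_tdesc accepts for strided sources, the explicit shape/strides form may be needed):

// take a rank-reduced 2D slice of the 3D source at outer offset %i
%slice = memref.subview %src[%i, 0, 0] [1, 16, 32] [1, 1, 1] : memref<8x16x32xf32> to memref<16x32xf32, strided<[32, 1], offset: ?>>
// create a 2D descriptor on the slice and apply only 2D offsets at load_nd
%desc = xegpu.create_nd_tdesc %slice : memref<16x32xf32, strided<[32, 1], offset: ?>> -> !xegpu.tensor_desc<16x32xf32>
%res = xegpu.load_nd %desc[%off0, %off1] : !xegpu.tensor_desc<16x32xf32> -> vector<16x32xf32>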
