[mlir][xegpu] Remove OffsetSizeAndStrideOpInterface from CreateNdDescOp #152773


Merged
merged 3 commits into llvm:main from chencha3:refine_create_nd
Aug 8, 2025

Conversation

chencha3
Contributor

@chencha3 chencha3 commented Aug 8, 2025

As the XeGPU design moves offsets out of CreateNdDescOp, OffsetSizeAndStrideOpInterface is no longer meaningful in that context.

@chencha3 chencha3 marked this pull request as ready for review August 8, 2025 18:20
@llvmbot
Member

llvmbot commented Aug 8, 2025

@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir

Author: Chao Chen (chencha3)

Changes

As the XeGPU design moves offsets out of CreateNdDescOp, OffsetSizeAndStrideOpInterface is no longer meaningful in that context.
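The patch replaces the interface-provided accessors with hand-rolled `getMixedOffsets`/`getMixedSizes`/`getMixedStrides` that merge a static integer array with the dynamic operand list via `mlir::getMixedValues`. A minimal standalone C++ sketch of that merging logic, using a hypothetical `FoldResult` struct as a stand-in for MLIR's `OpFoldResult` and a `kDynamic` sentinel modeled on `ShapedType::kDynamic` (both names are illustrative, not MLIR's actual API):

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical stand-in for MLIR's OpFoldResult: each entry is either
// a static constant or a dynamic SSA value (modeled here as a name).
struct FoldResult {
  bool isStatic;
  int64_t constant;  // valid when isStatic
  std::string value; // valid when !isStatic
};

// Sentinel marking a dynamic slot in the static array, in the spirit
// of ShapedType::kDynamic.
constexpr int64_t kDynamic = INT64_MIN;

// Merge a static array with the dynamic operand list: every kDynamic
// slot consumes the next dynamic value, mirroring what the new
// getMixedSizes/getMixedStrides accessors delegate to getMixedValues.
std::vector<FoldResult>
mergeMixedValues(const std::vector<int64_t> &statics,
                 const std::vector<std::string> &dynamics) {
  std::vector<FoldResult> result;
  size_t dynIdx = 0;
  for (int64_t s : statics) {
    if (s == kDynamic)
      result.push_back({false, 0, dynamics[dynIdx++]});
    else
      result.push_back({true, s, ""});
  }
  assert(dynIdx == dynamics.size() && "all dynamic values consumed");
  return result;
}
```

For a descriptor over `memref<8x?xf32>` with one dynamic extent, `mergeMixedValues({8, kDynamic}, {"%h"})` yields a static `8` followed by the dynamic `%h`, which is the mixed representation the verifier ranks are computed from.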


Patch is 24.64 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/152773.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td (+63-108)
  • (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp (+5-2)
  • (modified) mlir/test/Dialect/XeGPU/invalid.mlir (+5-5)
  • (modified) mlir/test/Dialect/XeGPU/ops.mlir (+18-18)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 75b16a87e03c6..1a6a34c8d775a 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -29,7 +29,7 @@ class XeGPU_Op<string mnemonic, list<Trait> traits = []>:
     void printProperties(::mlir::MLIRContext *ctx,
             ::mlir::OpAsmPrinter &p, const Properties &prop,
             ::mlir::ArrayRef<::llvm::StringRef> elidedProps) {
-      
+
       DictionaryAttr propAttr = dyn_cast_if_present<mlir::DictionaryAttr>(getPropertiesAsAttr(ctx, prop));
 
       // filter out the elidedProps from propAttr, and get the resultAttr
@@ -43,7 +43,7 @@ class XeGPU_Op<string mnemonic, list<Trait> traits = []>:
       }
 
       if (!filteredAttrs.empty()) {
-        p << "<" << DictionaryAttr::get(ctx, filteredAttrs) << ">"; 
+        p << "<" << DictionaryAttr::get(ctx, filteredAttrs) << ">";
       }
     }
 
@@ -60,8 +60,7 @@ class XeGPU_Op<string mnemonic, list<Trait> traits = []>:
 }
 
 
-def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface,
-                        AttrSizedOperandSegments, OffsetSizeAndStrideOpInterface]> {
+def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface, AttrSizedOperandSegments]> {
 
   let summary = "Create nd-tensor descriptor operation";
   let description = [{
@@ -181,82 +180,38 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
       return getType().getShape();
     }
 
-    /// wrapper for matching with OffsetSizeAndStrideOpInterface
-    OperandRange getSizes() {
-      return getShape();
+    SmallVector<OpFoldResult> getMixedOffsets() {
+      auto statics = getConstOffsets().value_or(SmallVector<int64_t>());
+      auto dynamics = getOffsets();
+      if (statics.size() == 0 && dynamics.size() == 0)
+        return {};
+      return getMixedValues(statics, dynamics, getContext());
     }
 
-    ArrayRef<int64_t> getStaticOffsets(){
-      auto attr = getConstOffsetsAttr();
-
-      if (attr) 
-        return attr;
+    SmallVector<OpFoldResult> getMixedSizes() {
+      SmallVector<int64_t> statics;
 
-      int64_t rank = getMixedSizes().size();
-      
-      setConstOffsets(llvm::SmallVector<int64_t, 4>(rank, 0));
+      /// Get the static sizes/shape, the value passed to const_shape
+      /// will overide the value in memref shape.
+      if (auto memrefTy = llvm::dyn_cast<MemRefType>(getSourceType()))
+        statics = llvm::to_vector(memrefTy.getShape());
+      if (auto attr = getConstShapeAttr())
+        statics = llvm::to_vector(attr.asArrayRef());
 
-      attr = getConstOffsetsAttr();
-      return attr;
+      return getMixedValues(statics, getShape(), getContext());
     }
 
-    /// wrapper for matching with OffsetSizeAndStrideOpInterface
-    /// If source is IntegerType or `const_shape` is filled,
-    /// it will return `const_shape`, such that mixes of `shape`
-    /// and `const_shape` will be used to represent the shape of
-    /// source operand. They overide static shape from source memref type.
-    ArrayRef<int64_t> getStaticSizes() {
-      /// To be compatible with OffsetSizeAndStrideOpInterface, which expects valid return value and perform checks
-      static  llvm::SmallVector<int64_t, 4> emptyShape;
-
-      auto attr = getConstShapeAttr();
-      if (attr)
-        return attr;
-
-      if (llvm::isa<IntegerType>(getSourceType()))
-        return emptyShape;
-
-      auto memrefType = llvm::dyn_cast<MemRefType>(getSourceType());
-      assert(memrefType && "Incorrect use of getStaticSizes");
-      return memrefType.getShape();
-    }
+    SmallVector<OpFoldResult> getMixedStrides() {
+      SmallVector<int64_t> statics;
 
-    /// wrapper for matching with OffsetSizeAndStrideOpInterface
-    /// If source is IntegerType or `const_strides` is filled, it
-    /// will return `const_strides`, such that mixes of `strides`
-    /// and `const_strides` will be used to represent the strides of
-    /// source operand. They overide static strides from source memref type.
-    ArrayRef<int64_t> getStaticStrides() {
-      /// To be compatible with OffsetSizeAndStrideOpInterface, which expects valid return value and perform checks
-      static llvm::SmallVector<int64_t, 4> emptyStrides;
-
-      auto attr = getConstStridesAttr();
-      if (attr)
-        return attr;
-      
-      if (llvm::isa<IntegerType>(getSourceType()))
-        return emptyStrides;
-
-      auto memrefType = llvm::dyn_cast<MemRefType>(getSourceType());
-      assert(memrefType && "Incorrect use of getStaticStrides");
-      auto [strides, _] = memrefType.getStridesAndOffset();
-      // reuse the storage of ConstStridesAttr since strides from
-      // memref is not persistant
-      setConstStrides(strides);
-      attr = getConstStridesAttr();
-      return attr;
-    }
+      /// Get the static strides, the value passed to const_strides
+      /// will overide the value in memref.
+      if (auto memrefTy = llvm::dyn_cast<MemRefType>(getSourceType()))
+        statics = memrefTy.getStridesAndOffset().first;
+      if (auto attr = getConstStridesAttr())
+        statics = llvm::to_vector(attr.asArrayRef());
 
-    /// Return the expected rank of each of the`static_offsets`,
-    /// `static_shape` and `static_strides` attributes.
-    std::array<unsigned, 3> getArrayAttrMaxRanks() {
-      unsigned rank;
-      if (auto ty = llvm::dyn_cast<MemRefType>(getSourceType())) {
-        rank = ty.getRank();
-      } else {
-        rank = (unsigned)getMixedOffsets().size();
-      }
-      return {rank, rank, rank};
+      return getMixedValues(statics, getStrides(), getContext());
     }
 
     /// Return the number of leading operands before the `offsets`,
@@ -314,15 +269,15 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", []> {
   }];
 
   let assemblyFormat = [{
-    $TensorDesc `` 
-    custom<OptionalDynamicIndexList>($offsets, $const_offsets) 
+    $TensorDesc ``
+    custom<OptionalDynamicIndexList>($offsets, $const_offsets)
     prop-dict attr-dict `:` qualified(type($TensorDesc))
   }];
 
   let builders = [
-    OpBuilder<(ins "Value": $TensorDesc, 
-                   "xegpu::CachePolicyAttr": $l1_hint, 
-                   "xegpu::CachePolicyAttr": $l2_hint, 
+    OpBuilder<(ins "Value": $TensorDesc,
+                   "xegpu::CachePolicyAttr": $l1_hint,
+                   "xegpu::CachePolicyAttr": $l2_hint,
                    "xegpu::CachePolicyAttr": $l3_hint)>
   ];
 
@@ -370,7 +325,7 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
 
   let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
                        Variadic<Index>: $offsets,
-                       OptionalAttr<DenseI64ArrayAttr>: $const_offsets,  
+                       OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
                        OptionalAttr<UnitAttr>: $packed,
                        OptionalAttr<DenseI64ArrayAttr>: $transpose,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
@@ -390,16 +345,16 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
   }];
 
   let assemblyFormat = [{
-    $TensorDesc `` 
-    custom<OptionalDynamicIndexList>($offsets, $const_offsets) 
+    $TensorDesc ``
+    custom<OptionalDynamicIndexList>($offsets, $const_offsets)
     prop-dict attr-dict `:` qualified(type($TensorDesc)) `->` type($value)
   }];
 
   let builders = [
-    OpBuilder<(ins "Type": $value, "Value": $TensorDesc, 
+    OpBuilder<(ins "Type": $value, "Value": $TensorDesc,
                     "UnitAttr": $packed, "DenseI64ArrayAttr": $transpose,
-                    "xegpu::CachePolicyAttr": $l1_hint, 
-                    "xegpu::CachePolicyAttr": $l2_hint, 
+                    "xegpu::CachePolicyAttr": $l1_hint,
+                    "xegpu::CachePolicyAttr": $l2_hint,
                     "xegpu::CachePolicyAttr": $l3_hint)>
   ];
 
@@ -442,7 +397,7 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
   let arguments = (ins XeGPU_ValueType: $value,
                        XeGPU_TensorDesc: $TensorDesc,
                        Variadic<Index>: $offsets,
-                       OptionalAttr<DenseI64ArrayAttr>: $const_offsets,  
+                       OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
@@ -458,16 +413,16 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
   }];
 
    let assemblyFormat = [{
-    $value `,` 
-    $TensorDesc `` 
-    custom<OptionalDynamicIndexList>($offsets, $const_offsets) 
+    $value `,`
+    $TensorDesc ``
+    custom<OptionalDynamicIndexList>($offsets, $const_offsets)
     prop-dict attr-dict `:`  type($value) `,` qualified(type($TensorDesc))
   }];
 
   let builders = [
-    OpBuilder<(ins "Value": $value, "Value": $TensorDesc, 
-                   "xegpu::CachePolicyAttr": $l1_hint, 
-                   "xegpu::CachePolicyAttr": $l2_hint, 
+    OpBuilder<(ins "Value": $value, "Value": $TensorDesc,
+                   "xegpu::CachePolicyAttr": $l1_hint,
+                   "xegpu::CachePolicyAttr": $l2_hint,
                    "xegpu::CachePolicyAttr": $l3_hint)>
   ];
 
@@ -635,12 +590,12 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
                              l3_hint = #xegpu.cache_hint<cached>}
         : !xegpu.tensor_desc<16xf16>
     ```
-    
+
     Example 2:
     A variant accepts memref as base pointer and an offset instead of scattered TensorTdesc.
     It combines "create scattered TensorTdesc" and "prefetch with scattered TensorTdesc".
     The source operand could be a raw pointer (uint64_t).
-    Please refer to create_tdesc for the restriction of memref. 
+    Please refer to create_tdesc for the restriction of memref.
     ```mlir
       %a = memref.alloc() : memref<1024xf32>
       %0 = arith.constant dense<[0, 16, 32, 64]> : vector<4xindex>
@@ -676,16 +631,16 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
   }];
 
   let assemblyFormat = [{
-    $source 
+    $source
     (`[` $offsets^ `]`)?
     prop-dict
-    attr-dict `:` type(operands) 
+    attr-dict `:` type(operands)
   }];
-    
+
   let builders = [
     OpBuilder<(ins "Value": $source,
-                    "xegpu::CachePolicyAttr": $l1_hint, 
-                    "xegpu::CachePolicyAttr": $l2_hint, 
+                    "xegpu::CachePolicyAttr": $l1_hint,
+                    "xegpu::CachePolicyAttr": $l2_hint,
                     "xegpu::CachePolicyAttr": $l3_hint)>
   ];
 
@@ -723,7 +678,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
           : !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>>,
             vector<16xi1> -> vector<16x8xf32>
   ```
-  
+
   Example 3 (SIMT mode):
   ```mlir
     %2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<cached>,
@@ -732,12 +687,12 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
           : !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>>
             vector<16xi1> -> vector<8xf32>
   ```
-  
+
   Example 4:
   A variant accepts memref as base pointer and an offset instead of scattered TensorTdesc.
   It combines "create scattered TensorTdesc" and "load with scattered TensorTdesc".
   The source operand could be a raw pointer (uint64_t). Please refer to create_tdesc
-  for the restriction of memref. 
+  for the restriction of memref.
   ```mlir
     %a = memref.alloc() : memref<1024xf32>
     %offsets = vector.step : vector<16xindex>
@@ -794,14 +749,14 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
   let assemblyFormat = [{
     $source
     (`[` $offsets^ `]`)? `,`
-    $mask prop-dict 
+    $mask prop-dict
     attr-dict `:` type(operands) `->` type($value)
   }];
 
   let builders = [
     OpBuilder<(ins "Type": $value, "Value": $source, "Value": $mask,
-                    "xegpu::CachePolicyAttr": $l1_hint, 
-                    "xegpu::CachePolicyAttr": $l2_hint, 
+                    "xegpu::CachePolicyAttr": $l1_hint,
+                    "xegpu::CachePolicyAttr": $l2_hint,
                     "xegpu::CachePolicyAttr": $l3_hint)>
    ];
 
@@ -848,7 +803,7 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
   A variant accepts memref as base pointer and an offset instead of scattered TensorTdesc.
   It combines "create scattered TensorTdesc" and "store with scattered TensorTdesc".
   The dest operand could be a raw pointer (uint64_t).
-  Please refer to create_tdesc for the restriction of memref. 
+  Please refer to create_tdesc for the restriction of memref.
   ```mlir
     %a = memref.alloc() : memref<1024xf32>
     %val = arith.constant dense<0.0> : vector<16xf32>
@@ -901,15 +856,15 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
     $value `,`
     $dest
     (`[` $offsets^ `]`)? `,`
-    $mask 
-    prop-dict 
+    $mask
+    prop-dict
     attr-dict `:`  type(operands)
   }];
 
   let builders = [
     OpBuilder<(ins "Value": $value, "Value": $dest, "Value": $mask,
-                    "xegpu::CachePolicyAttr": $l1_hint, 
-                    "xegpu::CachePolicyAttr": $l2_hint, 
+                    "xegpu::CachePolicyAttr": $l1_hint,
+                    "xegpu::CachePolicyAttr": $l2_hint,
                     "xegpu::CachePolicyAttr": $l3_hint)>
    ];
 
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 33450f3fa229e..7ac885d2ed40f 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -265,7 +265,7 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
 }
 
 LogicalResult CreateNdDescOp::verify() {
-  auto rank = (int64_t)getMixedOffsets().size();
+  int64_t rank = getMixedSizes().size();
   bool invalidRank = false;
   bool invalidElemTy = false;
 
@@ -280,6 +280,9 @@ LogicalResult CreateNdDescOp::verify() {
            << " Source: " << srcMemorySpace
            << ", TensorDesc: " << tdescMemorySpace;
 
+  if (int64_t offsetRank = getMixedOffsets().size())
+    invalidRank |= (offsetRank != rank);
+
   // check source type matches the rank if it is a memref.
   // It also should have the same ElementType as TensorDesc.
   auto memrefTy = dyn_cast<MemRefType>(getSourceType());
@@ -291,7 +294,7 @@ LogicalResult CreateNdDescOp::verify() {
   if (llvm::isa<IntegerType>(getSourceType())) {
     // strides and shape must present for integer source.
     if (getMixedStrides().empty() || getMixedSizes().empty())
-      return emitOpError("Expecting strides and shape to be present for "
+      return emitOpError("expecting strides and shape to be present for "
                          "integer source.");
   }
 
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index dff3ffab39ecf..cdf147a9fdd0e 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -52,14 +52,14 @@ func.func @create_nd_tdesc_7(%src: memref<128x128xf32>) {
 
 // -----
 func.func @create_nd_tdesc_8(%src: ui64) {
-  // expected-error@+1 {{'xegpu.create_nd_tdesc' op Expecting strides and shape to be present for integer source}}
+  // expected-error@+1 {{'xegpu.create_nd_tdesc' op expecting strides and shape to be present for integer source}}
   %1 = xegpu.create_nd_tdesc %src : ui64-> !xegpu.tensor_desc<128x128xf32>
   return
 }
 
 // -----
 func.func @create_nd_tdesc_9(%src: ui64) {
-  // expected-error@+1 {{expected mixed offsets rank to match mixed sizes rank}}
+  // expected-error@+1 {{expecting strides and shape to be present for integer source}}
   %1 = xegpu.create_nd_tdesc %src[0, 0] : ui64-> !xegpu.tensor_desc<128x128xf32>
   return
 }
@@ -149,7 +149,7 @@ func.func @subgroup_load_nd_offset_2(%src: memref<4x8x16xf16>, %x : index) {
 }
 
 // -----
-func.func @subgroup_load_nd_offset_3(%src: memref<4x8x16xf16>, %x : index) {  
+func.func @subgroup_load_nd_offset_3(%src: memref<4x8x16xf16>, %x : index) {
   %3 = xegpu.create_nd_tdesc %src: memref<4x8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
   %5 = xegpu.load_nd %3[0, 0] : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
     // expected-error@+1 {{Mismatched ranks between offsets and tensor descriptor}}
@@ -418,7 +418,7 @@ func.func @store_scatter_offset_wi_1(%src: memref<?xf16>) {
   %offsets = arith.constant dense<[0]> : vector<1xindex>
   %mask = arith.constant dense<1>: vector<1xi1>
   // expected-error@+1 {{value elements must match chunk size}}
-  xegpu.store %val, %src[%offsets], %mask 
+  xegpu.store %val, %src[%offsets], %mask
         : vector<4xf16>, memref<?xf16>, vector<1xindex>, vector<1xi1>
   return
 }
@@ -429,7 +429,7 @@ func.func @store_scatter_offset_wi_2(%src: memref<4x4xf16>) {
   %offsets = arith.constant dense<[0]> : vector<1xindex>
   %mask = arith.constant dense<1>: vector<1xi1>
   // expected-error@+1 {{Expecting the dest is a 1D memref or pointer}}
-  xegpu.store %val, %src[%offsets], %mask 
+  xegpu.store %val, %src[%offsets], %mask
         : vector<4xf16>, memref<4x4xf16>, vector<1xindex>, vector<1xi1>
   return
 }
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index 6be2371d4d7b2..67c00f5a9cc2f 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -62,28 +62,28 @@ gpu.func @create_nd_tdesc_7(%src: memref<8x24x32x48x64xf32>) {
 }
 
 
-// CHECK: gpu.func @test_create_nd_tdesc_7(%[[arg0:.*]]: ui64, %[[arg1:.*]]: index, %[[arg2:.*]]: index, %[[arg3:.*]]: index, %[[arg4:.*]]: index, %[[arg5:.*]]: memref<24x32xf32>) 
+// CHECK: gpu.func @test_create_nd_tdesc_7(%[[arg0:.*]]: ui64, %[[arg1:.*]]: index, %[[arg2:.*]]: index, %[[arg3:.*]]: index, %[[arg4:.*]]: index, %[[arg5:.*]]: memref<24x32xf32>)
 gpu.func @test_create_nd_tdesc_7(%src: ui64, %w : index, %h : index, %x : index, %y : index, %src2: memref<24x32xf32>) {
   //CHECK: %[[C:.*]] = arith.constant 1 : index
   %c1 = arith.constant 1 : index
-  
-  // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg5]][0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+
+  // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg5]] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
   %3 = xegpu.create_nd_tdesc %src2 : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
- 
+
   gpu.return
 }
 
-// CHECK: gpu.func @test_create_nd_tdesc_8(%[[arg0:.*]]: ui64, %[[arg1:.*]]: index, %[[arg2:.*]]: index, %[[arg3:.*]]: index, %[[arg4:.*]]: index) 
+// CHECK: gpu.func @test_create_nd_tdesc_8(%[[arg0:.*]]: ui64, %[[arg1:.*]]: index, %[[arg2:.*]]: index, %[[arg3:.*]]: index, %[[arg4:.*]]: index)
 gpu.func @test_create_nd_tdesc_8(%src: ui64, %w : index, %h : index, %x : index, %y : index) {
-  
-  %c1 = arith.constant 1 : index   
-  // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[0, 0], shape : [%arg2, %arg1], strides : [%arg1, %c1] : ui64 -> !xegpu.tensor_desc<8x16xf32>
+
+  %c1 = arith.constant 1 : index
+  // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0, shape : [%arg2, %arg1], strides : [%arg1, %c1] : ui64 -> !xegpu.tensor_desc<8x16xf32>
   %2 = xegpu.create_nd_tdesc %src, shape : [%h, %w], strides : [%w, %c1]  : ui64 -> !xegpu.tensor_desc<8x16xf32>
- 
+
   gpu.return
 }
 
-// CHECK-LABEL: func @test_create_nd_tdesc_9({{.*}}) 
+// CHECK-LABEL: func @test_create_nd_tdesc_9({{.*}})
 
 gpu.func @test_create_nd_tdesc_9(%src: memref<?x?xf16>, %w : index, %h : index, %x : index, %y : index) {
 
@@ -94,10 +94,10 @@ gpu.func @test_create_nd_tdesc_9(%src: memref<?x?xf16>, %w : index, %h : index,
   gpu.return
 }
 
-// CHECK-LABEL: func @test_create_nd_tdesc_10({{.*}}) 
-gpu.func @test_create_nd_tdesc_10(%src: memref<?x?xf16>, %w : index, %h : index, %x : index, %y : index) {  
+// CHECK-LABEL: func @test_create_nd_tdesc_10({{.*}})
+gpu.func @test_create_nd_tdesc_10(%src: memref<?x?xf16>, %w : index, %h : index, %x : index, %y : index) {
   %c1 = arith.constant 1 : index
-  // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[0, 0], shape : [%arg2, %arg1], strides : [%arg1, %c1] : memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16> 
+  // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0, shape : [%arg2, %arg1], strides : [%arg1, %c1] : memref<?x?xf16> -...
[truncated]

@@ -60,8 +60,7 @@ class XeGPU_Op<string mnemonic, list<Trait> traits = []>:
 }


-def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface,
-                        AttrSizedOperandSegments, OffsetSizeAndStrideOpInterface]> {
+def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface, AttrSizedOperandSegments]> {


Is the op description section aligned with the changes now?

Contributor


The change only impacts the implementation.

Contributor Author


I updated the description, marking offsets as deprecated.

Contributor

@Jianhui-Li Jianhui-Li left a comment


LGTM. Thanks!

Contributor

@charithaintc charithaintc left a comment


LGTM.

@chencha3 chencha3 merged commit fab2b22 into llvm:main Aug 8, 2025
9 checks passed
@chencha3 chencha3 deleted the refine_create_nd branch August 8, 2025 21:59