diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 75b16a87e03c6..1a6a34c8d775a 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -29,7 +29,7 @@ class XeGPU_Op<string mnemonic, list<Trait> traits = []>:
     void printProperties(::mlir::MLIRContext *ctx,
                          ::mlir::OpAsmPrinter &p, const Properties &prop,
                          ::mlir::ArrayRef<::llvm::StringRef> elidedProps) {
-      
+
      DictionaryAttr propAttr =
          dyn_cast_if_present<DictionaryAttr>(getPropertiesAsAttr(ctx, prop));
      // filter out the elidedProps from propAttr, and get the resultAttr
@@ -43,7 +43,7 @@ class XeGPU_Op<string mnemonic, list<Trait> traits = []>:
      }

      if (!filteredAttrs.empty()) {
-        p << "<" << DictionaryAttr::get(ctx, filteredAttrs) << ">"; 
+        p << "<" << DictionaryAttr::get(ctx, filteredAttrs) << ">";
      }
    }

@@ -60,8 +60,7 @@ class XeGPU_Op<string mnemonic, list<Trait> traits = []>:
 }


-def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface,
-                        AttrSizedOperandSegments, OffsetSizeAndStrideOpInterface]> {
+def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface, AttrSizedOperandSegments]> {

   let summary = "Create nd-tensor descriptor operation";
   let description = [{
@@ -181,82 +180,38 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
       return getType().getShape();
     }

-    /// wrapper for matching with OffsetSizeAndStrideOpInterface
-    OperandRange getSizes() {
-      return getShape();
+    SmallVector<OpFoldResult> getMixedOffsets() {
+      auto statics = getConstOffsets().value_or(SmallVector<int64_t>());
+      auto dynamics = getOffsets();
+      if (statics.size() == 0 && dynamics.size() == 0)
+        return {};
+      return getMixedValues(statics, dynamics, getContext());
     }

-    ArrayRef<int64_t> getStaticOffsets(){
-      auto attr = getConstOffsetsAttr();
-
-      if (attr)
-        return attr;
+    SmallVector<OpFoldResult> getMixedSizes() {
+      SmallVector<int64_t> statics;

-      int64_t rank = getMixedSizes().size();
-
-      setConstOffsets(llvm::SmallVector<int64_t>(rank, 0));
+      /// Get the static sizes/shape, the value passed to const_shape
+      /// will override the value in memref shape.
+      if (auto memrefTy = llvm::dyn_cast<MemRefType>(getSourceType()))
+        statics = llvm::to_vector(memrefTy.getShape());
+      if (auto attr = getConstShapeAttr())
+        statics = llvm::to_vector(attr.asArrayRef());

-      attr = getConstOffsetsAttr();
-      return attr;
+      return getMixedValues(statics, getShape(), getContext());
     }

-    /// wrapper for matching with OffsetSizeAndStrideOpInterface
-    /// If source is IntegerType or `const_shape` is filled,
-    /// it will return `const_shape`, such that mixes of `shape`
-    /// and `const_shape` will be used to represent the shape of
-    /// source operand. They overide static shape from source memref type.
-    ArrayRef<int64_t> getStaticSizes() {
-      /// To be compatible with OffsetSizeAndStrideOpInterface, which expects valid return value and perform checks
-      static llvm::SmallVector<int64_t> emptyShape;
-
-      auto attr = getConstShapeAttr();
-      if (attr)
-        return attr;
-
-      if (llvm::isa<IntegerType>(getSourceType()))
-        return emptyShape;
-
-      auto memrefType = llvm::dyn_cast<MemRefType>(getSourceType());
-      assert(memrefType && "Incorrect use of getStaticSizes");
-      return memrefType.getShape();
-    }
+    SmallVector<OpFoldResult> getMixedStrides() {
+      SmallVector<int64_t> statics;

-    /// wrapper for matching with OffsetSizeAndStrideOpInterface
-    /// If source is IntegerType or `const_strides` is filled, it
-    /// will return `const_strides`, such that mixes of `strides`
-    /// and `const_strides` will be used to represent the strides of
-    /// source operand. They overide static strides from source memref type.
-    ArrayRef<int64_t> getStaticStrides() {
-      /// To be compatible with OffsetSizeAndStrideOpInterface, which expects valid return value and perform checks
-      static llvm::SmallVector<int64_t> emptyStrides;
-
-      auto attr = getConstStridesAttr();
-      if (attr)
-        return attr;
-
-      if (llvm::isa<IntegerType>(getSourceType()))
-        return emptyStrides;
-
-      auto memrefType = llvm::dyn_cast<MemRefType>(getSourceType());
-      assert(memrefType && "Incorrect use of getStaticStrides");
-      auto [strides, _] = memrefType.getStridesAndOffset();
-      // reuse the storage of ConstStridesAttr since strides from
-      // memref is not persistant
-      setConstStrides(strides);
-      attr = getConstStridesAttr();
-      return attr;
-    }
+      /// Get the static strides, the value passed to const_strides
+      /// will override the value in memref.
+      if (auto memrefTy = llvm::dyn_cast<MemRefType>(getSourceType()))
+        statics = memrefTy.getStridesAndOffset().first;
+      if (auto attr = getConstStridesAttr())
+        statics = llvm::to_vector(attr.asArrayRef());

-    /// Return the expected rank of each of the`static_offsets`,
-    /// `static_shape` and `static_strides` attributes.
-    std::array<unsigned, 3> getArrayAttrMaxRanks() {
-      unsigned rank;
-      if (auto ty = llvm::dyn_cast<MemRefType>(getSourceType())) {
-        rank = ty.getRank();
-      } else {
-        rank = (unsigned)getMixedOffsets().size();
-      }
-      return {rank, rank, rank};
+      return getMixedValues(statics, getStrides(), getContext());
     }

     /// Return the number of leading operands before the `offsets`,
@@ -314,15 +269,15 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", []> {
   }];

   let assemblyFormat = [{
-    $TensorDesc `` 
-    custom<OptionalDynamicIndexList>($offsets, $const_offsets) 
+    $TensorDesc ``
+    custom<OptionalDynamicIndexList>($offsets, $const_offsets)
     prop-dict attr-dict `:` qualified(type($TensorDesc))
   }];

   let builders = [
-    OpBuilder<(ins "Value": $TensorDesc, 
-                   "xegpu::CachePolicyAttr": $l1_hint, 
-                   "xegpu::CachePolicyAttr": $l2_hint, 
+    OpBuilder<(ins "Value": $TensorDesc,
+                   "xegpu::CachePolicyAttr": $l1_hint,
+                   "xegpu::CachePolicyAttr": $l2_hint,
                    "xegpu::CachePolicyAttr": $l3_hint)>
  ];

@@ -370,7 +325,7 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
   let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
                        Variadic<Index>: $offsets,
-                       OptionalAttr<DenseI64ArrayAttr>: $const_offsets, 
+                       OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
                        OptionalAttr<UnitAttr>: $packed,
                        OptionalAttr<DenseI64ArrayAttr>: $transpose,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
@@ -390,16 +345,16 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
   }];

   let assemblyFormat = [{
-    $TensorDesc `` 
-    custom<OptionalDynamicIndexList>($offsets, $const_offsets) 
+    $TensorDesc ``
+    custom<OptionalDynamicIndexList>($offsets, $const_offsets)
     prop-dict attr-dict `:` qualified(type($TensorDesc)) `->` type($value)
   }];

   let builders = [
-    OpBuilder<(ins "Type": $value, "Value": $TensorDesc, 
+    OpBuilder<(ins "Type": $value, "Value": $TensorDesc,
                    "UnitAttr": $packed, "DenseI64ArrayAttr": $transpose,
-                   "xegpu::CachePolicyAttr": $l1_hint, 
-                   "xegpu::CachePolicyAttr": $l2_hint, 
+                   "xegpu::CachePolicyAttr": $l1_hint,
+                   "xegpu::CachePolicyAttr": $l2_hint,
                    "xegpu::CachePolicyAttr": $l3_hint)>
  ];

@@ -442,7 +397,7 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
   let arguments = (ins XeGPU_ValueType: $value, XeGPU_TensorDesc: $TensorDesc,
                        Variadic<Index>: $offsets,
-                       OptionalAttr<DenseI64ArrayAttr>: $const_offsets, 
+                       OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
@@ -458,16 +413,16 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
   }];

   let assemblyFormat = [{
-    $value `,` 
-    $TensorDesc `` 
-    custom<OptionalDynamicIndexList>($offsets, $const_offsets) 
+    $value `,`
+    $TensorDesc ``
+    custom<OptionalDynamicIndexList>($offsets, $const_offsets)
     prop-dict attr-dict `:` type($value) `,` qualified(type($TensorDesc))
   }];

   let builders = [
-    OpBuilder<(ins "Value": $value, "Value": $TensorDesc, 
-                   "xegpu::CachePolicyAttr": $l1_hint, 
-                   "xegpu::CachePolicyAttr": $l2_hint, 
+    OpBuilder<(ins "Value": $value, "Value": $TensorDesc,
+                   "xegpu::CachePolicyAttr": $l1_hint,
+                   "xegpu::CachePolicyAttr": $l2_hint,
                    "xegpu::CachePolicyAttr": $l3_hint)>
  ];

@@ -635,12 +590,12 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
                                 l3_hint = #xegpu.cache_hint}
        : !xegpu.tensor_desc<16xf16>
    ```
-    
+
    Example 2:
    A variant accepts memref as base pointer and an offset instead of scattered TensorTdesc.
    It combines "create scattered TensorTdesc" and "prefetch with scattered TensorTdesc".
    The source operand could be a raw pointer (uint64_t).
-    Please refer to create_tdesc for the restriction of memref. 
+    Please refer to create_tdesc for the restriction of memref.
    ```mlir
      %a = memref.alloc() : memref<1024xf32>
      %0 = arith.constant dense<[0, 16, 32, 64]> : vector<4xindex>
@@ -676,16 +631,16 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
  }];

  let assemblyFormat = [{
-    $source 
+    $source
    (`[` $offsets^ `]`)?
    prop-dict
-    attr-dict `:` type(operands) 
+    attr-dict `:` type(operands)
  }];
-  
+
  let builders = [
    OpBuilder<(ins "Value": $source,
-                   "xegpu::CachePolicyAttr": $l1_hint, 
-                   "xegpu::CachePolicyAttr": $l2_hint, 
+                   "xegpu::CachePolicyAttr": $l1_hint,
+                   "xegpu::CachePolicyAttr": $l2_hint,
                    "xegpu::CachePolicyAttr": $l3_hint)>
  ];

@@ -723,7 +678,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
            : !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr>,
              vector<16xi1> -> vector<16x8xf32>
    ```
-    
+
    Example 3 (SIMT mode):
    ```mlir
      %2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint,
@@ -732,12 +687,12 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
            : !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr>
              vector<16xi1> -> vector<8xf32>
    ```
-    
+
    Example 4:
    A variant accepts memref as base pointer and an offset instead of scattered TensorTdesc.
    It combines "create scattered TensorTdesc" and "load with scattered TensorTdesc".
    The source operand could be a raw pointer (uint64_t). Please refer to create_tdesc
-    for the restriction of memref. 
+    for the restriction of memref.
    ```mlir
      %a = memref.alloc() : memref<1024xf32>
      %offsets = vector.step : vector<16xindex>
@@ -794,14 +749,14 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {

  let assemblyFormat = [{
    $source
    (`[` $offsets^ `]`)? `,`
-    $mask prop-dict 
+    $mask prop-dict
    attr-dict `:` type(operands) `->` type($value)
  }];

  let builders = [
    OpBuilder<(ins "Type": $value, "Value": $source, "Value": $mask,
-                   "xegpu::CachePolicyAttr": $l1_hint, 
-                   "xegpu::CachePolicyAttr": $l2_hint, 
+                   "xegpu::CachePolicyAttr": $l1_hint,
+                   "xegpu::CachePolicyAttr": $l2_hint,
                    "xegpu::CachePolicyAttr": $l3_hint)>
  ];

@@ -848,7 +803,7 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
    A variant accepts memref as base pointer and an offset instead of scattered TensorTdesc.
    It combines "create scattered TensorTdesc" and "store with scattered TensorTdesc".
    The dest operand could be a raw pointer (uint64_t).
-    Please refer to create_tdesc for the restriction of memref. 
+    Please refer to create_tdesc for the restriction of memref.
    ```mlir
      %a = memref.alloc() : memref<1024xf32>
      %val = arith.constant dense<0.0> : vector<16xf32>
@@ -901,15 +856,15 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
    $value `,`
    $dest
    (`[` $offsets^ `]`)? `,`
-    $mask 
-    prop-dict 
+    $mask
+    prop-dict
    attr-dict `:` type(operands)
  }];

  let builders = [
    OpBuilder<(ins "Value": $value, "Value": $dest, "Value": $mask,
-                   "xegpu::CachePolicyAttr": $l1_hint, 
-                   "xegpu::CachePolicyAttr": $l2_hint, 
+                   "xegpu::CachePolicyAttr": $l1_hint,
+                   "xegpu::CachePolicyAttr": $l2_hint,
                    "xegpu::CachePolicyAttr": $l3_hint)>
  ];

diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 33450f3fa229e..b519d6ad72660 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -265,8 +265,8 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
 }

 LogicalResult CreateNdDescOp::verify() {
-  auto rank = (int64_t)getMixedOffsets().size();
-  bool invalidRank = false;
+  size_t rank = getMixedSizes().size();
+  bool invalidRank = rank != getMixedStrides().size();
   bool invalidElemTy = false;

   // Memory space of created TensorDesc should match with the source.
@@ -280,31 +280,28 @@ LogicalResult CreateNdDescOp::verify() {
            << " Source: " << srcMemorySpace
            << ", TensorDesc: " << tdescMemorySpace;

+  if (size_t offsetRank = getMixedOffsets().size())
+    invalidRank |= (offsetRank != rank);
+
   // check source type matches the rank if it is a memref.
   // It also should have the same ElementType as TensorDesc.
-  auto memrefTy = dyn_cast<MemRefType>(getSourceType());
-  if (memrefTy) {
-    invalidRank |= (memrefTy.getRank() != rank);
+  if (auto memrefTy = dyn_cast<MemRefType>(getSourceType()))
     invalidElemTy |= memrefTy.getElementType() != getElementType();
-  }

   if (llvm::isa<IntegerType>(getSourceType())) {
     // strides and shape must present for integer source.
     if (getMixedStrides().empty() || getMixedSizes().empty())
-      return emitOpError("Expecting strides and shape to be present for "
+      return emitOpError("expecting strides and shape to be present for "
                          "integer source.");
   }

-  // mismatches among shape, strides, and offsets are
-  // already handeled by OffsetSizeAndStrideOpInterface.
-  // So they are not check here.
   if (invalidRank)
     return emitOpError(
         "Expecting the rank of shape, strides, offsets, and source (if source "
         "is a memref) should match with each other.");

   // check result TensorDesc rank
-  if (getType().getRank() > rank)
+  if (getType().getRank() > (int64_t)rank)
     return emitOpError(
         "Expecting the TensorDesc rank is not greater than the "
         "ranks of shape, strides, offsets or the memref source.");
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index dff3ffab39ecf..cdf147a9fdd0e 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -52,14 +52,14 @@ func.func @create_nd_tdesc_7(%src: memref<128x128xf32>) {

 // -----
 func.func @create_nd_tdesc_8(%src: ui64) {
-  // expected-error@+1 {{'xegpu.create_nd_tdesc' op Expecting strides and shape to be present for integer source}}
+  // expected-error@+1 {{'xegpu.create_nd_tdesc' op expecting strides and shape to be present for integer source}}
   %1 = xegpu.create_nd_tdesc %src : ui64-> !xegpu.tensor_desc<128x128xf32>
   return
 }

 // -----
 func.func @create_nd_tdesc_9(%src: ui64) {
-  // expected-error@+1 {{expected mixed offsets rank to match mixed sizes rank}}
+  // expected-error@+1 {{expecting strides and shape to be present for integer source}}
   %1 = xegpu.create_nd_tdesc %src[0, 0] : ui64-> !xegpu.tensor_desc<128x128xf32>
   return
 }
@@ -149,7 +149,7 @@ func.func @subgroup_load_nd_offset_2(%src: memref<4x8x16xf16>, %x : index) {
 }

 // -----
-func.func @subgroup_load_nd_offset_3(%src: memref<4x8x16xf16>, %x : index) { 
+func.func @subgroup_load_nd_offset_3(%src: memref<4x8x16xf16>, %x : index) {
   %3 = xegpu.create_nd_tdesc %src: memref<4x8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
   %5 = xegpu.load_nd %3[0, 0] : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
   // expected-error@+1 {{Mismatched ranks between offsets and tensor descriptor}}
@@ -418,7 +418,7 @@ func.func @store_scatter_offset_wi_1(%src: memref<?xf16>) {
   %offsets = arith.constant dense<[0]> : vector<1xindex>
   %mask = arith.constant dense<1>: vector<1xi1>
   // expected-error@+1 {{value elements must match chunk size}}
-  xegpu.store %val, %src[%offsets], %mask 
+  xegpu.store %val, %src[%offsets], %mask
         : vector<4xf16>, memref<?xf16>, vector<1xindex>, vector<1xi1>
   return
 }
@@ -429,7 +429,7 @@ func.func @store_scatter_offset_wi_2(%src: memref<4x4xf16>) {
   %offsets = arith.constant dense<[0]> : vector<1xindex>
   %mask = arith.constant dense<1>: vector<1xi1>
   // expected-error@+1 {{Expecting the dest is a 1D memref or pointer}}
-  xegpu.store %val, %src[%offsets], %mask 
+  xegpu.store %val, %src[%offsets], %mask
         : vector<4xf16>, memref<4x4xf16>, vector<1xindex>, vector<1xi1>
   return
 }
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index 6be2371d4d7b2..67c00f5a9cc2f 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -62,28 +62,28 @@ gpu.func @create_nd_tdesc_7(%src: memref<8x24x32x48x64xf32>) {
 }


-// CHECK: gpu.func @test_create_nd_tdesc_7(%[[arg0:.*]]: ui64, %[[arg1:.*]]: index, %[[arg2:.*]]: index, %[[arg3:.*]]: index, %[[arg4:.*]]: index, %[[arg5:.*]]: memref<24x32xf32>) 
+// CHECK: gpu.func @test_create_nd_tdesc_7(%[[arg0:.*]]: ui64, %[[arg1:.*]]: index, %[[arg2:.*]]: index, %[[arg3:.*]]: index, %[[arg4:.*]]: index, %[[arg5:.*]]: memref<24x32xf32>)
 gpu.func @test_create_nd_tdesc_7(%src: ui64, %w : index, %h : index, %x : index, %y : index, %src2: memref<24x32xf32>) {
   //CHECK: %[[C:.*]] = arith.constant 1 : index
   %c1 = arith.constant 1 : index
-  
-  // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg5]][0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+
+  // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg5]] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
   %3 = xegpu.create_nd_tdesc %src2 : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
-  
+
   gpu.return
 }

-// CHECK: gpu.func @test_create_nd_tdesc_8(%[[arg0:.*]]: ui64, %[[arg1:.*]]: index, %[[arg2:.*]]: index, %[[arg3:.*]]: index, %[[arg4:.*]]: index) 
+// CHECK: gpu.func @test_create_nd_tdesc_8(%[[arg0:.*]]: ui64, %[[arg1:.*]]: index, %[[arg2:.*]]: index, %[[arg3:.*]]: index, %[[arg4:.*]]: index)
 gpu.func @test_create_nd_tdesc_8(%src: ui64, %w : index, %h : index, %x : index, %y : index) {
-  
-  %c1 = arith.constant 1 : index 
-  // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[0, 0], shape : [%arg2, %arg1], strides : [%arg1, %c1] : ui64 -> !xegpu.tensor_desc<8x16xf32>
+
+  %c1 = arith.constant 1 : index
+  // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0, shape : [%arg2, %arg1], strides : [%arg1, %c1] : ui64 -> !xegpu.tensor_desc<8x16xf32>
   %2 = xegpu.create_nd_tdesc %src, shape : [%h, %w], strides : [%w, %c1] : ui64 -> !xegpu.tensor_desc<8x16xf32>
-  
+
   gpu.return
 }

-// CHECK-LABEL: func @test_create_nd_tdesc_9({{.*}}) 
+// CHECK-LABEL: func @test_create_nd_tdesc_9({{.*}})
 gpu.func @test_create_nd_tdesc_9(%src: memref<?x?xf16>, %w : index, %h : index, %x : index, %y : index) {

@@ -94,10 +94,10 @@ gpu.func @test_create_nd_tdesc_9(%src: memref<?x?xf16>, %w : index, %h : index,
   gpu.return
 }

-// CHECK-LABEL: func @test_create_nd_tdesc_10({{.*}}) 
-gpu.func @test_create_nd_tdesc_10(%src: memref<?x?xf16>, %w : index, %h : index, %x : index, %y : index) { 
+// CHECK-LABEL: func @test_create_nd_tdesc_10({{.*}})
+gpu.func @test_create_nd_tdesc_10(%src: memref<?x?xf16>, %w : index, %h : index, %x : index, %y : index) {
   %c1 = arith.constant 1 : index
-  // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[0, 0], shape : [%arg2, %arg1], strides : [%arg1, %c1] : memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16>
+  // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0, shape : [%arg2, %arg1], strides : [%arg1, %c1] : memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16>
   %2 = xegpu.create_nd_tdesc %src, shape:[%h, %w], strides:[%w, %c1] : memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16>

   gpu.return
@@ -123,7 +123,7 @@ gpu.func @prefetch_nd_2(%src: memref<48x64xf16>) {
 }

 // CHECK: gpu.func @prefetch_nd_offset_1(%[[arg0:.*]]: memref<48x64xf16>, %arg1: index, %arg2: index) {
 gpu.func @prefetch_nd_offset_1(%src: memref<48x64xf16>, %x : index, %y : index) {
-  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<48x64xf16> -> !xegpu.tensor_desc<8x16xf16>
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]] : memref<48x64xf16> -> !xegpu.tensor_desc<8x16xf16>
   %1 = xegpu.create_nd_tdesc %src : memref<48x64xf16> -> !xegpu.tensor_desc<8x16xf16>
   // CHECK: xegpu.prefetch_nd %[[R0]][%arg1, %arg2] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x16xf16>
   xegpu.prefetch_nd %1[%x, %y] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}>: !xegpu.tensor_desc<8x16xf16>
@@ -271,7 +271,7 @@ gpu.func @subgroup_load_nd_8(%src: memref<24x32xf32>) {

 // CHECK: func @subgroup_load_nd_offset_1(%[[arg0:.*]]: memref<24x32xf32>, %arg1: index, %arg2: index) {
 gpu.func @subgroup_load_nd_offset_1(%src: memref<24x32xf32>, %x : index, %y : index) {
-  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32>
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0 : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32>
   %1 = xegpu.create_nd_tdesc %src : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32>
   // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]][%arg1, %arg2] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, transpose = array}> : !xegpu.tensor_desc<16x8xf32> -> vector<8x16xf32>
   %2 = xegpu.load_nd %1[%x, %y] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, transpose = array}> : !xegpu.tensor_desc<16x8xf32> -> vector<8x16xf32>
@@ -290,7 +290,7 @@ gpu.func @simt_load_nd_8(%src: memref<24x32xf32>) {

 // CHECK: func @simt_load_nd_offset_1(%[[arg0:.*]]: memref<24x32xf32>) {
 gpu.func @simt_load_nd_offset_1(%src: memref<24x32xf32>) {
-  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32>
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0 : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32>
   %1 = xegpu.create_nd_tdesc %src : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32>
   // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]][0, 0] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, transpose = array}> : !xegpu.tensor_desc<16x8xf32> -> vector<8xf32>
   %2 = xegpu.load_nd %1[0, 0] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, transpose = array}> : !xegpu.tensor_desc<16x8xf32> -> vector<8xf32>
@@ -323,7 +323,7 @@ gpu.func @simt_store_nd(%src: memref<24x32xf16>) {
 gpu.func @subgroup_store_nd_2(%dst: memref<24x32xf16>, %x : index) {
   // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<32xf16>
   %1 = arith.constant dense<1.0>: vector<32xf16>
-  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
   %2 = xegpu.create_nd_tdesc %dst : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
   // CHECK: xegpu.store_nd %[[C]], %[[R0]][%arg1] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : vector<32xf16>, !xegpu.tensor_desc<32xf16>
   xegpu.store_nd %1, %2[%x] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}>: vector<32xf16>, !xegpu.tensor_desc<32xf16>
@@ -356,7 +356,7 @@ gpu.func @simt_store_nd_2(%src: memref<24x32xf16>) {
 gpu.func @simt_store_nd_offset_1(%src: memref<24x32xf16>) {
   // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<2xf16>
   %1 = arith.constant dense<1.0>: vector<2xf16>
-  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0 : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
   %2 = xegpu.create_nd_tdesc %src : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
   // CHECK: xegpu.store_nd %[[C]], %[[R0]][0] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : vector<2xf16>, !xegpu.tensor_desc<32xf16>
   xegpu.store_nd %1, %2[0] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}>: vector<2xf16>, !xegpu.tensor_desc<32xf16>
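For readers skimming the patch, the user-visible effect is that `xegpu.create_nd_tdesc` no longer materializes an implicit `[0, 0]` offset list; offsets now travel with the consuming access op, and the verifier derives the expected rank from the mixed sizes/strides, cross-checking offsets only when they are present. A minimal usage sketch, adapted from the updated ops.mlir tests above (the function name is made up for illustration, and cache hints are omitted):

```mlir
gpu.func @sketch_nd_tdesc_usage(%src: memref<24x32xf32>, %x: index, %y: index) {
  // The descriptor now captures only the base memref (or a ui64 pointer plus
  // explicit shape/strides); no offsets are baked in at creation time.
  %tdesc = xegpu.create_nd_tdesc %src : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
  // Offsets are supplied on the access op instead.
  %tile = xegpu.load_nd %tdesc[%x, %y] : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
  gpu.return
}
```

This matches the intent of replacing the `OffsetSizeAndStrideOpInterface` hooks with the hand-written `getMixedOffsets`/`getMixedSizes`/`getMixedStrides` helpers: offsets become optional, per-access information rather than part of the descriptor's static state.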