-
Notifications
You must be signed in to change notification settings - Fork 14.7k
[mlir][xegpu] Remove OffsetSizeAndStrideOpInterface from CreateNdDescOp #152773
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-gpu @llvm/pr-subscribers-mlir Author: Chao Chen (chencha3). Changes: As the XeGPU design is moving `offsets` out of `CreateNdDescOp`, `OffsetSizeAndStrideOpInterface` is no longer meaningful in that context. Patch is 24.64 KiB, truncated to 20.00 KiB below; full version: https://github.com/llvm/llvm-project/pull/152773.diff — 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 75b16a87e03c6..1a6a34c8d775a 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -29,7 +29,7 @@ class XeGPU_Op<string mnemonic, list<Trait> traits = []>:
void printProperties(::mlir::MLIRContext *ctx,
::mlir::OpAsmPrinter &p, const Properties &prop,
::mlir::ArrayRef<::llvm::StringRef> elidedProps) {
-
+
DictionaryAttr propAttr = dyn_cast_if_present<mlir::DictionaryAttr>(getPropertiesAsAttr(ctx, prop));
// filter out the elidedProps from propAttr, and get the resultAttr
@@ -43,7 +43,7 @@ class XeGPU_Op<string mnemonic, list<Trait> traits = []>:
}
if (!filteredAttrs.empty()) {
- p << "<" << DictionaryAttr::get(ctx, filteredAttrs) << ">";
+ p << "<" << DictionaryAttr::get(ctx, filteredAttrs) << ">";
}
}
@@ -60,8 +60,7 @@ class XeGPU_Op<string mnemonic, list<Trait> traits = []>:
}
-def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface,
- AttrSizedOperandSegments, OffsetSizeAndStrideOpInterface]> {
+def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface, AttrSizedOperandSegments]> {
let summary = "Create nd-tensor descriptor operation";
let description = [{
@@ -181,82 +180,38 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
return getType().getShape();
}
- /// wrapper for matching with OffsetSizeAndStrideOpInterface
- OperandRange getSizes() {
- return getShape();
+ SmallVector<OpFoldResult> getMixedOffsets() {
+ auto statics = getConstOffsets().value_or(SmallVector<int64_t>());
+ auto dynamics = getOffsets();
+ if (statics.size() == 0 && dynamics.size() == 0)
+ return {};
+ return getMixedValues(statics, dynamics, getContext());
}
- ArrayRef<int64_t> getStaticOffsets(){
- auto attr = getConstOffsetsAttr();
-
- if (attr)
- return attr;
+ SmallVector<OpFoldResult> getMixedSizes() {
+ SmallVector<int64_t> statics;
- int64_t rank = getMixedSizes().size();
-
- setConstOffsets(llvm::SmallVector<int64_t, 4>(rank, 0));
+ /// Get the static sizes/shape, the value passed to const_shape
+    /// will override the value in memref shape.
+ if (auto memrefTy = llvm::dyn_cast<MemRefType>(getSourceType()))
+ statics = llvm::to_vector(memrefTy.getShape());
+ if (auto attr = getConstShapeAttr())
+ statics = llvm::to_vector(attr.asArrayRef());
- attr = getConstOffsetsAttr();
- return attr;
+ return getMixedValues(statics, getShape(), getContext());
}
- /// wrapper for matching with OffsetSizeAndStrideOpInterface
- /// If source is IntegerType or `const_shape` is filled,
- /// it will return `const_shape`, such that mixes of `shape`
- /// and `const_shape` will be used to represent the shape of
- /// source operand. They overide static shape from source memref type.
- ArrayRef<int64_t> getStaticSizes() {
- /// To be compatible with OffsetSizeAndStrideOpInterface, which expects valid return value and perform checks
- static llvm::SmallVector<int64_t, 4> emptyShape;
-
- auto attr = getConstShapeAttr();
- if (attr)
- return attr;
-
- if (llvm::isa<IntegerType>(getSourceType()))
- return emptyShape;
-
- auto memrefType = llvm::dyn_cast<MemRefType>(getSourceType());
- assert(memrefType && "Incorrect use of getStaticSizes");
- return memrefType.getShape();
- }
+ SmallVector<OpFoldResult> getMixedStrides() {
+ SmallVector<int64_t> statics;
- /// wrapper for matching with OffsetSizeAndStrideOpInterface
- /// If source is IntegerType or `const_strides` is filled, it
- /// will return `const_strides`, such that mixes of `strides`
- /// and `const_strides` will be used to represent the strides of
- /// source operand. They overide static strides from source memref type.
- ArrayRef<int64_t> getStaticStrides() {
- /// To be compatible with OffsetSizeAndStrideOpInterface, which expects valid return value and perform checks
- static llvm::SmallVector<int64_t, 4> emptyStrides;
-
- auto attr = getConstStridesAttr();
- if (attr)
- return attr;
-
- if (llvm::isa<IntegerType>(getSourceType()))
- return emptyStrides;
-
- auto memrefType = llvm::dyn_cast<MemRefType>(getSourceType());
- assert(memrefType && "Incorrect use of getStaticStrides");
- auto [strides, _] = memrefType.getStridesAndOffset();
- // reuse the storage of ConstStridesAttr since strides from
- // memref is not persistant
- setConstStrides(strides);
- attr = getConstStridesAttr();
- return attr;
- }
+ /// Get the static strides, the value passed to const_strides
+    /// will override the value in memref.
+ if (auto memrefTy = llvm::dyn_cast<MemRefType>(getSourceType()))
+ statics = memrefTy.getStridesAndOffset().first;
+ if (auto attr = getConstStridesAttr())
+ statics = llvm::to_vector(attr.asArrayRef());
- /// Return the expected rank of each of the`static_offsets`,
- /// `static_shape` and `static_strides` attributes.
- std::array<unsigned, 3> getArrayAttrMaxRanks() {
- unsigned rank;
- if (auto ty = llvm::dyn_cast<MemRefType>(getSourceType())) {
- rank = ty.getRank();
- } else {
- rank = (unsigned)getMixedOffsets().size();
- }
- return {rank, rank, rank};
+ return getMixedValues(statics, getStrides(), getContext());
}
/// Return the number of leading operands before the `offsets`,
@@ -314,15 +269,15 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", []> {
}];
let assemblyFormat = [{
- $TensorDesc ``
- custom<OptionalDynamicIndexList>($offsets, $const_offsets)
+ $TensorDesc ``
+ custom<OptionalDynamicIndexList>($offsets, $const_offsets)
prop-dict attr-dict `:` qualified(type($TensorDesc))
}];
let builders = [
- OpBuilder<(ins "Value": $TensorDesc,
- "xegpu::CachePolicyAttr": $l1_hint,
- "xegpu::CachePolicyAttr": $l2_hint,
+ OpBuilder<(ins "Value": $TensorDesc,
+ "xegpu::CachePolicyAttr": $l1_hint,
+ "xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l3_hint)>
];
@@ -370,7 +325,7 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
Variadic<Index>: $offsets,
- OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
+ OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
OptionalAttr<UnitAttr>: $packed,
OptionalAttr<DenseI64ArrayAttr>: $transpose,
OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
@@ -390,16 +345,16 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
}];
let assemblyFormat = [{
- $TensorDesc ``
- custom<OptionalDynamicIndexList>($offsets, $const_offsets)
+ $TensorDesc ``
+ custom<OptionalDynamicIndexList>($offsets, $const_offsets)
prop-dict attr-dict `:` qualified(type($TensorDesc)) `->` type($value)
}];
let builders = [
- OpBuilder<(ins "Type": $value, "Value": $TensorDesc,
+ OpBuilder<(ins "Type": $value, "Value": $TensorDesc,
"UnitAttr": $packed, "DenseI64ArrayAttr": $transpose,
- "xegpu::CachePolicyAttr": $l1_hint,
- "xegpu::CachePolicyAttr": $l2_hint,
+ "xegpu::CachePolicyAttr": $l1_hint,
+ "xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l3_hint)>
];
@@ -442,7 +397,7 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
let arguments = (ins XeGPU_ValueType: $value,
XeGPU_TensorDesc: $TensorDesc,
Variadic<Index>: $offsets,
- OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
+ OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
@@ -458,16 +413,16 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
}];
let assemblyFormat = [{
- $value `,`
- $TensorDesc ``
- custom<OptionalDynamicIndexList>($offsets, $const_offsets)
+ $value `,`
+ $TensorDesc ``
+ custom<OptionalDynamicIndexList>($offsets, $const_offsets)
prop-dict attr-dict `:` type($value) `,` qualified(type($TensorDesc))
}];
let builders = [
- OpBuilder<(ins "Value": $value, "Value": $TensorDesc,
- "xegpu::CachePolicyAttr": $l1_hint,
- "xegpu::CachePolicyAttr": $l2_hint,
+ OpBuilder<(ins "Value": $value, "Value": $TensorDesc,
+ "xegpu::CachePolicyAttr": $l1_hint,
+ "xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l3_hint)>
];
@@ -635,12 +590,12 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
l3_hint = #xegpu.cache_hint<cached>}
: !xegpu.tensor_desc<16xf16>
```
-
+
Example 2:
A variant accepts memref as base pointer and an offset instead of scattered TensorTdesc.
It combines "create scattered TensorTdesc" and "prefetch with scattered TensorTdesc".
The source operand could be a raw pointer (uint64_t).
- Please refer to create_tdesc for the restriction of memref.
+ Please refer to create_tdesc for the restriction of memref.
```mlir
%a = memref.alloc() : memref<1024xf32>
%0 = arith.constant dense<[0, 16, 32, 64]> : vector<4xindex>
@@ -676,16 +631,16 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
}];
let assemblyFormat = [{
- $source
+ $source
(`[` $offsets^ `]`)?
prop-dict
- attr-dict `:` type(operands)
+ attr-dict `:` type(operands)
}];
-
+
let builders = [
OpBuilder<(ins "Value": $source,
- "xegpu::CachePolicyAttr": $l1_hint,
- "xegpu::CachePolicyAttr": $l2_hint,
+ "xegpu::CachePolicyAttr": $l1_hint,
+ "xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l3_hint)>
];
@@ -723,7 +678,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
: !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>>,
vector<16xi1> -> vector<16x8xf32>
```
-
+
Example 3 (SIMT mode):
```mlir
%2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<cached>,
@@ -732,12 +687,12 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
: !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>>
vector<16xi1> -> vector<8xf32>
```
-
+
Example 4:
A variant accepts memref as base pointer and an offset instead of scattered TensorTdesc.
It combines "create scattered TensorTdesc" and "load with scattered TensorTdesc".
The source operand could be a raw pointer (uint64_t). Please refer to create_tdesc
- for the restriction of memref.
+ for the restriction of memref.
```mlir
%a = memref.alloc() : memref<1024xf32>
%offsets = vector.step : vector<16xindex>
@@ -794,14 +749,14 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
let assemblyFormat = [{
$source
(`[` $offsets^ `]`)? `,`
- $mask prop-dict
+ $mask prop-dict
attr-dict `:` type(operands) `->` type($value)
}];
let builders = [
OpBuilder<(ins "Type": $value, "Value": $source, "Value": $mask,
- "xegpu::CachePolicyAttr": $l1_hint,
- "xegpu::CachePolicyAttr": $l2_hint,
+ "xegpu::CachePolicyAttr": $l1_hint,
+ "xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l3_hint)>
];
@@ -848,7 +803,7 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
A variant accepts memref as base pointer and an offset instead of scattered TensorTdesc.
It combines "create scattered TensorTdesc" and "store with scattered TensorTdesc".
The dest operand could be a raw pointer (uint64_t).
- Please refer to create_tdesc for the restriction of memref.
+ Please refer to create_tdesc for the restriction of memref.
```mlir
%a = memref.alloc() : memref<1024xf32>
%val = arith.constant dense<0.0> : vector<16xf32>
@@ -901,15 +856,15 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
$value `,`
$dest
(`[` $offsets^ `]`)? `,`
- $mask
- prop-dict
+ $mask
+ prop-dict
attr-dict `:` type(operands)
}];
let builders = [
OpBuilder<(ins "Value": $value, "Value": $dest, "Value": $mask,
- "xegpu::CachePolicyAttr": $l1_hint,
- "xegpu::CachePolicyAttr": $l2_hint,
+ "xegpu::CachePolicyAttr": $l1_hint,
+ "xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l3_hint)>
];
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 33450f3fa229e..7ac885d2ed40f 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -265,7 +265,7 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
}
LogicalResult CreateNdDescOp::verify() {
- auto rank = (int64_t)getMixedOffsets().size();
+ int64_t rank = getMixedSizes().size();
bool invalidRank = false;
bool invalidElemTy = false;
@@ -280,6 +280,9 @@ LogicalResult CreateNdDescOp::verify() {
<< " Source: " << srcMemorySpace
<< ", TensorDesc: " << tdescMemorySpace;
+ if (int64_t offsetRank = getMixedOffsets().size())
+ invalidRank |= (offsetRank != rank);
+
// check source type matches the rank if it is a memref.
// It also should have the same ElementType as TensorDesc.
auto memrefTy = dyn_cast<MemRefType>(getSourceType());
@@ -291,7 +294,7 @@ LogicalResult CreateNdDescOp::verify() {
if (llvm::isa<IntegerType>(getSourceType())) {
// strides and shape must present for integer source.
if (getMixedStrides().empty() || getMixedSizes().empty())
- return emitOpError("Expecting strides and shape to be present for "
+ return emitOpError("expecting strides and shape to be present for "
"integer source.");
}
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index dff3ffab39ecf..cdf147a9fdd0e 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -52,14 +52,14 @@ func.func @create_nd_tdesc_7(%src: memref<128x128xf32>) {
// -----
func.func @create_nd_tdesc_8(%src: ui64) {
- // expected-error@+1 {{'xegpu.create_nd_tdesc' op Expecting strides and shape to be present for integer source}}
+ // expected-error@+1 {{'xegpu.create_nd_tdesc' op expecting strides and shape to be present for integer source}}
%1 = xegpu.create_nd_tdesc %src : ui64-> !xegpu.tensor_desc<128x128xf32>
return
}
// -----
func.func @create_nd_tdesc_9(%src: ui64) {
- // expected-error@+1 {{expected mixed offsets rank to match mixed sizes rank}}
+ // expected-error@+1 {{expecting strides and shape to be present for integer source}}
%1 = xegpu.create_nd_tdesc %src[0, 0] : ui64-> !xegpu.tensor_desc<128x128xf32>
return
}
@@ -149,7 +149,7 @@ func.func @subgroup_load_nd_offset_2(%src: memref<4x8x16xf16>, %x : index) {
}
// -----
-func.func @subgroup_load_nd_offset_3(%src: memref<4x8x16xf16>, %x : index) {
+func.func @subgroup_load_nd_offset_3(%src: memref<4x8x16xf16>, %x : index) {
%3 = xegpu.create_nd_tdesc %src: memref<4x8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
%5 = xegpu.load_nd %3[0, 0] : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
// expected-error@+1 {{Mismatched ranks between offsets and tensor descriptor}}
@@ -418,7 +418,7 @@ func.func @store_scatter_offset_wi_1(%src: memref<?xf16>) {
%offsets = arith.constant dense<[0]> : vector<1xindex>
%mask = arith.constant dense<1>: vector<1xi1>
// expected-error@+1 {{value elements must match chunk size}}
- xegpu.store %val, %src[%offsets], %mask
+ xegpu.store %val, %src[%offsets], %mask
: vector<4xf16>, memref<?xf16>, vector<1xindex>, vector<1xi1>
return
}
@@ -429,7 +429,7 @@ func.func @store_scatter_offset_wi_2(%src: memref<4x4xf16>) {
%offsets = arith.constant dense<[0]> : vector<1xindex>
%mask = arith.constant dense<1>: vector<1xi1>
// expected-error@+1 {{Expecting the dest is a 1D memref or pointer}}
- xegpu.store %val, %src[%offsets], %mask
+ xegpu.store %val, %src[%offsets], %mask
: vector<4xf16>, memref<4x4xf16>, vector<1xindex>, vector<1xi1>
return
}
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index 6be2371d4d7b2..67c00f5a9cc2f 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -62,28 +62,28 @@ gpu.func @create_nd_tdesc_7(%src: memref<8x24x32x48x64xf32>) {
}
-// CHECK: gpu.func @test_create_nd_tdesc_7(%[[arg0:.*]]: ui64, %[[arg1:.*]]: index, %[[arg2:.*]]: index, %[[arg3:.*]]: index, %[[arg4:.*]]: index, %[[arg5:.*]]: memref<24x32xf32>)
+// CHECK: gpu.func @test_create_nd_tdesc_7(%[[arg0:.*]]: ui64, %[[arg1:.*]]: index, %[[arg2:.*]]: index, %[[arg3:.*]]: index, %[[arg4:.*]]: index, %[[arg5:.*]]: memref<24x32xf32>)
gpu.func @test_create_nd_tdesc_7(%src: ui64, %w : index, %h : index, %x : index, %y : index, %src2: memref<24x32xf32>) {
//CHECK: %[[C:.*]] = arith.constant 1 : index
%c1 = arith.constant 1 : index
-
- // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg5]][0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+
+ // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg5]] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
%3 = xegpu.create_nd_tdesc %src2 : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
-
+
gpu.return
}
-// CHECK: gpu.func @test_create_nd_tdesc_8(%[[arg0:.*]]: ui64, %[[arg1:.*]]: index, %[[arg2:.*]]: index, %[[arg3:.*]]: index, %[[arg4:.*]]: index)
+// CHECK: gpu.func @test_create_nd_tdesc_8(%[[arg0:.*]]: ui64, %[[arg1:.*]]: index, %[[arg2:.*]]: index, %[[arg3:.*]]: index, %[[arg4:.*]]: index)
gpu.func @test_create_nd_tdesc_8(%src: ui64, %w : index, %h : index, %x : index, %y : index) {
-
- %c1 = arith.constant 1 : index
- // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[0, 0], shape : [%arg2, %arg1], strides : [%arg1, %c1] : ui64 -> !xegpu.tensor_desc<8x16xf32>
+
+ %c1 = arith.constant 1 : index
+ // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0, shape : [%arg2, %arg1], strides : [%arg1, %c1] : ui64 -> !xegpu.tensor_desc<8x16xf32>
%2 = xegpu.create_nd_tdesc %src, shape : [%h, %w], strides : [%w, %c1] : ui64 -> !xegpu.tensor_desc<8x16xf32>
-
+
gpu.return
}
-// CHECK-LABEL: func @test_create_nd_tdesc_9({{.*}})
+// CHECK-LABEL: func @test_create_nd_tdesc_9({{.*}})
gpu.func @test_create_nd_tdesc_9(%src: memref<?x?xf16>, %w : index, %h : index, %x : index, %y : index) {
@@ -94,10 +94,10 @@ gpu.func @test_create_nd_tdesc_9(%src: memref<?x?xf16>, %w : index, %h : index,
gpu.return
}
-// CHECK-LABEL: func @test_create_nd_tdesc_10({{.*}})
-gpu.func @test_create_nd_tdesc_10(%src: memref<?x?xf16>, %w : index, %h : index, %x : index, %y : index) {
+// CHECK-LABEL: func @test_create_nd_tdesc_10({{.*}})
+gpu.func @test_create_nd_tdesc_10(%src: memref<?x?xf16>, %w : index, %h : index, %x : index, %y : index) {
%c1 = arith.constant 1 : index
- // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[0, 0], shape : [%arg2, %arg1], strides : [%arg1, %c1] : memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16>
+ // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0, shape : [%arg2, %arg1], strides : [%arg1, %c1] : memref<?x?xf16> -...
[truncated]
|
@@ -60,8 +60,7 @@ class XeGPU_Op<string mnemonic, list<Trait> traits = []>: | |||
} | |||
|
|||
|
|||
def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface, | |||
AttrSizedOperandSegments, OffsetSizeAndStrideOpInterface]> { | |||
def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface, AttrSizedOperandSegments]> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is op description section aligned with the changes now?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The change only impacts the implementation.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I updated the description, marking offsets deprecated.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Thanks!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM.
As the XeGPU design is moving `offsets` out of `CreateNdDescOp`, `OffsetSizeAndStrideOpInterface` is no longer meaningful in such context.