Skip to content

Commit 0ef9ed7

Browse files
committed
Use memref.extract_strided_metadata to compute strides
Signed-off-by: dchigarev <[email protected]>
1 parent: beeac48 · commit: 0ef9ed7

File tree

6 files changed

+14
-35
lines changed

6 files changed

+14
-35
lines changed

mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp

Lines changed: 4 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -105,9 +105,9 @@ static xegpu::CreateNdDescOp createNdDescriptor(PatternRewriter &rewriter,
105105
auto [strides, offset] = srcTy.getStridesAndOffset();
106106

107107
xegpu::CreateNdDescOp ndDesc;
108-
if (srcTy.hasStaticShape())
108+
if (srcTy.hasStaticShape()) {
109109
ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src);
110-
else {
110+
} else {
111111
// In case of any dynamic shapes, source's shape and strides have to be
112112
// explicitly provided.
113113
SmallVector<Value> sourceDims;
@@ -123,21 +123,8 @@ static xegpu::CreateNdDescOp createNdDescriptor(PatternRewriter &rewriter,
123123
mixedShapes.push_back(rewriter.getI64IntegerAttr(shape));
124124
}
125125

126-
// Compute strides in reverse order.
127-
SmallVector<OpFoldResult> mixedStrides;
128-
Value accStride = arith::ConstantIndexOp::create(rewriter, loc, 1);
129-
// Last stride is guaranteed to be static and unit.
130-
mixedStrides.push_back(rewriter.getI64IntegerAttr(1));
131-
for (int i = static_cast<int>(strides.size()) - 2; i >= 0; --i) {
132-
accStride =
133-
arith::MulIOp::create(rewriter, loc, accStride, sourceDims[i + 1]);
134-
if (strides[i] == ShapedType::kDynamic)
135-
mixedStrides.push_back(accStride);
136-
else
137-
mixedStrides.push_back(rewriter.getI64IntegerAttr(strides[i]));
138-
}
139-
std::reverse(mixedStrides.begin(), mixedStrides.end());
140-
126+
auto meta = memref::ExtractStridedMetadataOp::create(rewriter, loc, src);
127+
SmallVector<OpFoldResult> mixedStrides(meta.getStrides().begin(), meta.getStrides().end());
141128
ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src,
142129
mixedShapes, mixedStrides);
143130
}

mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -52,9 +52,9 @@ func.func @load_dynamic_source(%source: memref<?x?x?xf32>,
5252
// CHECK-DAG: %[[DIM_0:.+]] = memref.dim %[[SRC]], %[[C0]]
5353
// CHECK-DAG: %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
5454
// CHECK-DAG: %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
55-
// CHECK: %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
55+
// CHECK: {{.*}} %[[STRIDES:.+]]:3 = memref.extract_strided_metadata %[[SRC]]
5656
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]]
57-
// CHECK-SAME: , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
57+
// CHECK-SAME: , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[STRIDES]]#0, %[[STRIDES]]#1, %c1]
5858
// CHECK-SAME: memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32>
5959
// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
6060
// CHECK: return %[[VEC]]

mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -54,9 +54,9 @@ func.func @store_dynamic_source(%vec: vector<8x16xf32>,
5454
// CHECK-DAG: %[[DIM_0:.+]] = memref.dim %[[SRC]], %[[C0]]
5555
// CHECK-DAG: %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
5656
// CHECK-DAG: %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
57-
// CHECK: %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
57+
// CHECK: {{.*}} %[[STRIDES:.+]]:3 = memref.extract_strided_metadata %[[SRC]]
5858
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]]
59-
// CHECK-SAME: , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
59+
// CHECK-SAME: , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[STRIDES]]#0, %[[STRIDES]]#1, %c1]
6060
// CHECK-SAME: memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32>
6161
// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>
6262

mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -150,7 +150,7 @@ gpu.func @load_dynamic_source(%source: memref<?x?x?xf32>,
150150
// LOAD-ND-DAG: %[[DIM_0:.+]] = memref.dim %[[SRC]], %[[C0]]
151151
// LOAD-ND-DAG: %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
152152
// LOAD-ND-DAG: %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
153-
// LOAD-ND: %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
153+
// LOAD-ND: {{.*}} %[[STRIDES:.+]]:3 = memref.extract_strided_metadata %[[SRC]]
154154
// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]]
155155
// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET:.+]], %[[OFFSET:.+]], %[[OFFSET:.+]]]{{.*}}-> vector<8x16xf32>
156156
// LOAD-ND: return %[[VEC]]
@@ -186,7 +186,7 @@ gpu.func @load_dynamic_source2(%source: memref<?x8x16xf32>,
186186
// LOAD-ND-LABEL: @load_dynamic_source2(
187187
// LOAD-ND-DAG: %[[C0:.+]] = arith.constant 0 : index
188188
// LOAD-ND-DAG: %[[DIM:.+]] = memref.dim %{{.*}}, %[[C0]] : memref<?x8x16xf32>
189-
// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %{{.*}}, shape : [%[[DIM]], 8, 16], strides : [128, 16, 1] : memref<?x8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>>
189+
// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %{{.*}}, shape : [%[[DIM]], 8, 16], strides : [%c128, %c16, %c1] : memref<?x8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>>
190190
// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%{{.*}}, %{{.*}}, %{{.*}}] : !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>> -> vector<8x16xf32>
191191
// LOAD-ND: return %[[VEC]] : vector<8x16xf32>
192192

mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -90,9 +90,9 @@ gpu.func @store_dynamic_source(%vec: vector<8x16xf32>,
9090
// STORE-ND-DAG: %[[DIM_0:.+]] = memref.dim %[[SRC]], %[[C0]]
9191
// STORE-ND-DAG: %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
9292
// STORE-ND-DAG: %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
93-
// STORE-ND: %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
93+
// STORE-ND: {{.*}} %[[STRIDES:.+]]:3 = memref.extract_strided_metadata %[[SRC]]
9494
// STORE-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]]
95-
// STORE-ND-SAME: , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
95+
// STORE-ND-SAME: , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[STRIDES]]#0, %[[STRIDES]]#1, %c1]
9696
// STORE-ND-SAME: memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32
9797
// STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>
9898

mlir/test/Dialect/XeGPU/invalid.mlir

Lines changed: 2 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -132,18 +132,10 @@ func.func @subgroup_load_nd_9(%src: memref<4x8x16xf16>) {
132132
return
133133
}
134134

135-
// -----
136-
func.func @subgroup_load_nd_offset_1(%src: memref<4x8x16xf16>, %x : index) {
137-
%1 = xegpu.create_nd_tdesc %src: memref<4x8x16xf16> -> !xegpu.tensor_desc<16xf16>
138-
// expected-error@+1 {{Offsets rank must match either the source or the TensorDesc rank.}}
139-
%2 = xegpu.load_nd %1[0, 0] : !xegpu.tensor_desc<16xf16> -> vector<16xf16>
140-
return
141-
}
142-
143135
// -----
144136
func.func @subgroup_load_nd_offset_2(%src: memref<4x8x16xf16>, %x : index) {
145137
%3 = xegpu.create_nd_tdesc %src: memref<4x8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
146-
// expected-error@+1 {{Offsets rank must match either the source or the TensorDesc rank.}}
138+
// expected-error@+1 {{Offsets rank cannot be smaller than tensor descriptor rank.}}
147139
xegpu.prefetch_nd %3[0] : !xegpu.tensor_desc<8x16xf16>
148140
return
149141
}
@@ -152,7 +144,7 @@ func.func @subgroup_load_nd_offset_2(%src: memref<4x8x16xf16>, %x : index) {
152144
func.func @subgroup_load_nd_offset_3(%src: memref<4x8x16xf16>, %x : index) {
153145
%3 = xegpu.create_nd_tdesc %src: memref<4x8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
154146
%5 = xegpu.load_nd %3[0, 0] : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
155-
// expected-error@+1 {{Offsets rank must match either the source or the TensorDesc rank.}}
147+
// expected-error@+1 {{Offsets rank cannot be smaller than tensor descriptor rank.}}
156148
xegpu.store_nd %5, %3[%x] : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16>
157149
return
158150
}

0 commit comments

Comments
 (0)