Skip to content

Commit e334d0b

Browse files
committed
use extractStridedMetadataOp to compute shapes for tdesc
Signed-off-by: dchigarev <[email protected]>
1 parent 392a01f commit e334d0b

File tree

5 files changed

+11
-50
lines changed

5 files changed

+11
-50
lines changed

mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -110,24 +110,10 @@ static xegpu::CreateNdDescOp createNdDescriptor(PatternRewriter &rewriter,
110110
} else {
111111
// In case of any dynamic shapes, source's shape and strides have to be
112112
// explicitly provided.
113-
SmallVector<Value> sourceDims;
114-
unsigned srcRank = srcTy.getRank();
115-
for (unsigned i = 0; i < srcRank; ++i)
116-
sourceDims.push_back(memref::DimOp::create(rewriter, loc, src, i));
117-
118-
SmallVector<OpFoldResult> mixedShapes;
119-
for (auto [idx, shape] : llvm::enumerate(srcTy.getShape())) {
120-
if (shape == ShapedType::kDynamic)
121-
mixedShapes.push_back(sourceDims[idx]);
122-
else
123-
mixedShapes.push_back(rewriter.getI64IntegerAttr(shape));
124-
}
125-
126113
auto meta = memref::ExtractStridedMetadataOp::create(rewriter, loc, src);
127-
SmallVector<OpFoldResult> mixedStrides(meta.getStrides().begin(),
128-
meta.getStrides().end());
129114
ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src,
130-
mixedShapes, mixedStrides);
115+
meta.getConstifiedMixedSizes(),
116+
meta.getConstifiedMixedStrides());
131117
}
132118

133119
return ndDesc;

mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -46,15 +46,9 @@ func.func @load_dynamic_source(%source: memref<?x?x?xf32>,
4646
// CHECK-LABEL: @load_dynamic_source(
4747
// CHECK-SAME: %[[SRC:.+]]: memref<?x?x?xf32>,
4848
// CHECK-SAME: %[[OFFSET:.+]]: index
49-
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
50-
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
51-
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
52-
// CHECK-DAG: %[[DIM_0:.+]] = memref.dim %[[SRC]], %[[C0]]
53-
// CHECK-DAG: %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
54-
// CHECK-DAG: %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
55-
// CHECK: {{.*}} %[[STRIDES:.+]]:3 = memref.extract_strided_metadata %[[SRC]]
49+
// CHECK: {{.*}} %[[SIZES:.+]]:3, %[[STRIDES:.+]]:3 = memref.extract_strided_metadata %[[SRC]]
5650
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]]
57-
// CHECK-SAME: , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[STRIDES]]#0, %[[STRIDES]]#1, %c1]
51+
// CHECK-SAME: , shape : [%[[SIZES]]#0, %[[SIZES]]#1, %[[SIZES]]#2], strides : [%[[STRIDES]]#0, %[[STRIDES]]#1, 1]
5852
// CHECK-SAME: memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32>
5953
// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
6054
// CHECK: return %[[VEC]]

mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -48,15 +48,9 @@ func.func @store_dynamic_source(%vec: vector<8x16xf32>,
4848
// CHECK-SAME: %[[VEC:.+]]: vector<8x16xf32>,
4949
// CHECK-SAME: %[[SRC:.+]]: memref<?x?x?xf32>,
5050
// CHECK-SAME: %[[OFFSET:.+]]: index
51-
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
52-
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
53-
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
54-
// CHECK-DAG: %[[DIM_0:.+]] = memref.dim %[[SRC]], %[[C0]]
55-
// CHECK-DAG: %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
56-
// CHECK-DAG: %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
57-
// CHECK: {{.*}} %[[STRIDES:.+]]:3 = memref.extract_strided_metadata %[[SRC]]
51+
// CHECK: {{.*}} %[[SIZES:.+]]:3, %[[STRIDES:.+]]:3 = memref.extract_strided_metadata %[[SRC]]
5852
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]]
59-
// CHECK-SAME: , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[STRIDES]]#0, %[[STRIDES]]#1, %c1]
53+
// CHECK-SAME: , shape : [%[[SIZES]]#0, %[[SIZES]]#1, %[[SIZES]]#2], strides : [%[[STRIDES]]#0, %[[STRIDES]]#1, 1]
6054
// CHECK-SAME: memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32>
6155
// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>
6256

mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -144,13 +144,7 @@ gpu.func @load_dynamic_source(%source: memref<?x?x?xf32>,
144144
// LOAD-ND-LABEL: @load_dynamic_source(
145145
// LOAD-ND-SAME: %[[SRC:.+]]: memref<?x?x?xf32>,
146146
// LOAD-ND-SAME: %[[OFFSET:.+]]: index
147-
// LOAD-ND: %[[C2:.+]] = arith.constant 2 : index
148-
// LOAD-ND: %[[C1:.+]] = arith.constant 1 : index
149-
// LOAD-ND: %[[C0:.+]] = arith.constant 0 : index
150-
// LOAD-ND-DAG: %[[DIM_0:.+]] = memref.dim %[[SRC]], %[[C0]]
151-
// LOAD-ND-DAG: %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
152-
// LOAD-ND-DAG: %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
153-
// LOAD-ND: {{.*}} %[[STRIDES:.+]]:3 = memref.extract_strided_metadata %[[SRC]]
147+
// LOAD-ND: {{.*}} %[[SIZES:.+]]:3, %[[STRIDES:.+]]:3 = memref.extract_strided_metadata %[[SRC]]
154148
// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]]
155149
// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET:.+]], %[[OFFSET:.+]], %[[OFFSET:.+]]]{{.*}}-> vector<8x16xf32>
156150
// LOAD-ND: return %[[VEC]]
@@ -184,9 +178,8 @@ gpu.func @load_dynamic_source2(%source: memref<?x8x16xf32>,
184178
}
185179

186180
// LOAD-ND-LABEL: @load_dynamic_source2(
187-
// LOAD-ND-DAG: %[[C0:.+]] = arith.constant 0 : index
188-
// LOAD-ND-DAG: %[[DIM:.+]] = memref.dim %{{.*}}, %[[C0]] : memref<?x8x16xf32>
189-
// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %{{.*}}, shape : [%[[DIM]], 8, 16], strides : [%c128, %c16, %c1] : memref<?x8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>>
181+
// LOAD-ND-DAG: {{.*}} %[[SIZES:.+]]:3, %[[STRIDES:.+]]:3 = memref.extract_strided_metadata
182+
// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %{{.*}}, shape : [%[[SIZES]]#0, 8, 16], strides : [128, 16, 1] : memref<?x8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>>
190183
// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%{{.*}}, %{{.*}}, %{{.*}}] : !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>> -> vector<8x16xf32>
191184
// LOAD-ND: return %[[VEC]] : vector<8x16xf32>
192185

mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -84,15 +84,9 @@ gpu.func @store_dynamic_source(%vec: vector<8x16xf32>,
8484
// STORE-ND-SAME: %[[VEC:.+]]: vector<8x16xf32>,
8585
// STORE-ND-SAME: %[[SRC:.+]]: memref<?x?x?xf32>,
8686
// STORE-ND-SAME: %[[OFFSET:.+]]: index
87-
// STORE-ND-DAG: %[[C0:.+]] = arith.constant 0 : index
88-
// STORE-ND-DAG: %[[C1:.+]] = arith.constant 1 : index
89-
// STORE-ND-DAG: %[[C2:.+]] = arith.constant 2 : index
90-
// STORE-ND-DAG: %[[DIM_0:.+]] = memref.dim %[[SRC]], %[[C0]]
91-
// STORE-ND-DAG: %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
92-
// STORE-ND-DAG: %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
93-
// STORE-ND: {{.*}} %[[STRIDES:.+]]:3 = memref.extract_strided_metadata %[[SRC]]
87+
// STORE-ND: {{.*}} %[[SIZES:.+]]:3, %[[STRIDES:.+]]:3 = memref.extract_strided_metadata
9488
// STORE-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]]
95-
// STORE-ND-SAME: , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[STRIDES]]#0, %[[STRIDES]]#1, %c1]
89+
// STORE-ND-SAME: , shape : [%[[SIZES]]#0, %[[SIZES]]#1, %[[SIZES]]#2], strides : [%[[STRIDES]]#0, %[[STRIDES]]#1, 1]
9690
// STORE-ND-SAME: memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32
9791
// STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>
9892

0 commit comments

Comments
 (0)