Skip to content

Commit 3709f67

Browse files
committed
address comments 1/2
1 parent a15b8ca commit 3709f67

File tree

3 files changed: +34 / -16 lines changed

3 files changed

+34
-16
lines changed

mlir/include/mlir/Transforms/Passes.td

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -589,9 +589,10 @@ def FuseMemorySpaceCastsIntoConsumers :
589589
Pass<"fuse-memory-space-casts-into-consumers"> {
590590
let summary = "Fuses memory-space cast operations into consumers.";
591591
let description = [{
592-
This pass tries to fuse all possible memory-space cast operations into their consumers.
593-
It does this by looking for `FuseMemorySpaceCastConsumerOpInterface`
594-
operations, and invoking the interface methods to perform the fusion.
592+
This pass tries to iteratively fuse all possible memory-space cast
593+
operations into their consumers. It does this by looking for
594+
`FuseMemorySpaceCastConsumerOpInterface` operations, and invoking the
595+
interface methods to perform the fusion.
595596

596597
Example:
597598

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1732,7 +1732,7 @@ TypedValue<PtrLikeTypeInterface> MemorySpaceCastOp::getTargetPtr() {
17321732

17331733
bool MemorySpaceCastOp::isValidMemorySpaceCast(PtrLikeTypeInterface tgt,
17341734
PtrLikeTypeInterface src) {
1735-
return isa<MemRefType>(tgt) &&
1735+
return isa<BaseMemRefType>(tgt) &&
17361736
tgt.clonePtrWith(src.getMemorySpace(), std::nullopt) == src;
17371737
}
17381738

mlir/test/Transforms/test-fuse-casts-into-consumers.mlir

Lines changed: 29 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -32,6 +32,23 @@ func.func @load_store_unfoldable(%arg0: memref<?xf32, 1>, %arg1: index) {
3232
return
3333
}
3434

35+
// CHECK-LABEL: func.func @cast(
36+
// CHECK-SAME: %[[ARG0:.*]]: memref<2xf32, 1>,
37+
// CHECK-SAME: %[[ARG1:.*]]: memref<*xf32, 1>) -> (memref<*xf32>, memref<3x2xf32>) {
38+
// CHECK: %[[VAL_0:.*]] = memref.cast %[[ARG0]] : memref<2xf32, 1> to memref<*xf32, 1>
39+
// CHECK: %[[VAL_1:.*]] = memref.memory_space_cast %[[VAL_0]] : memref<*xf32, 1> to memref<*xf32>
40+
// CHECK: %[[VAL_2:.*]] = memref.cast %[[ARG1]] : memref<*xf32, 1> to memref<3x2xf32, 1>
41+
// CHECK: %[[VAL_3:.*]] = memref.memory_space_cast %[[VAL_2]] : memref<3x2xf32, 1> to memref<3x2xf32>
42+
// CHECK: return %[[VAL_1]], %[[VAL_3]] : memref<*xf32>, memref<3x2xf32>
43+
// CHECK: }
44+
func.func @cast(%arg0: memref<2xf32, 1>, %arg1: memref<*xf32, 1>) -> (memref<*xf32>, memref<3x2xf32>) {
45+
%memspacecast = memref.memory_space_cast %arg0 : memref<2xf32, 1> to memref<2xf32>
46+
%1 = memref.cast %memspacecast : memref<2xf32> to memref<*xf32>
47+
%memspacecast_1 = memref.memory_space_cast %arg1 : memref<*xf32, 1> to memref<*xf32>
48+
%2 = memref.cast %memspacecast_1 : memref<*xf32> to memref<3x2xf32>
49+
return %1, %2 : memref<*xf32>, memref<3x2xf32>
50+
}
51+
3552
// CHECK-LABEL: func.func @view(
3653
// CHECK-SAME: %[[ARG0:.*]]: memref<?xi8, 1>,
3754
// CHECK-SAME: %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) -> memref<?x?xi8> {
@@ -63,8 +80,8 @@ func.func @subview(%arg0: memref<?x?xf32, 1>, %arg1: index) -> memref<8x2xf32, s
6380
// CHECK-LABEL: func.func @reinterpret_cast(
6481
// CHECK-SAME: %[[ARG0:.*]]: memref<?xf32, 1>,
6582
// CHECK-SAME: %[[ARG1:.*]]: index) -> memref<10x?xf32, strided<[?, 1], offset: ?>> {
66-
// CHECK: %[[VAL_0:.*]] = arith.constant 10 : index
67-
// CHECK: %[[VAL_1:.*]] = arith.constant 0 : index
83+
// CHECK-DAG: %[[VAL_0:.*]] = arith.constant 10 : index
84+
// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 0 : index
6885
// CHECK: %[[VAL_2:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: {{\[}}%[[VAL_1]]], sizes: [10, %[[VAL_0]]], strides: {{\[}}%[[VAL_0]], 1] : memref<?xf32, 1> to memref<10x?xf32, strided<[?, 1], offset: ?>, 1>
6986
// CHECK: %[[VAL_3:.*]] = memref.memory_space_cast %[[VAL_2]] : memref<10x?xf32, strided<[?, 1], offset: ?>, 1> to memref<10x?xf32, strided<[?, 1], offset: ?>>
7087
// CHECK: return %[[VAL_3]] : memref<10x?xf32, strided<[?, 1], offset: ?>>
@@ -155,8 +172,8 @@ func.func @assume_alignment(%arg0: memref<?xf32, 1>) -> memref<?xf32> {
155172
// CHECK-SAME: %[[ARG0:.*]]: memref<4x4xf32, 1>,
156173
// CHECK-SAME: %[[ARG1:.*]]: index,
157174
// CHECK-SAME: %[[ARG2:.*]]: f32) -> memref<16xf32> {
158-
// CHECK: %[[VAL_0:.*]] = arith.constant 4 : index
159-
// CHECK: %[[VAL_1:.*]] = arith.constant 0 : index
175+
// CHECK-DAG: %[[VAL_0:.*]] = arith.constant 4 : index
176+
// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 0 : index
160177
// CHECK: %[[VAL_2:.*]] = memref.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2]] output_shape [4, 2, 2] : memref<4x4xf32, 1> into memref<4x2x2xf32, 1>
161178
// CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_2]] {{\[\[}}0, 1, 2]] : memref<4x2x2xf32, 1> into memref<16xf32, 1>
162179
// CHECK: %[[VAL_4:.*]] = memref.memory_space_cast %[[VAL_3]] : memref<16xf32, 1> to memref<16xf32>
@@ -225,8 +242,8 @@ func.func @vector_load_store(%arg0: memref<?xf32, 1>, %arg1: index) {
225242
// CHECK-LABEL: func.func @masked_load_store(
226243
// CHECK-SAME: %[[ARG0:.*]]: memref<?xf32, 1>,
227244
// CHECK-SAME: %[[ARG1:.*]]: index) {
228-
// CHECK: %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
229-
// CHECK: %[[VAL_1:.*]] = arith.constant dense<[true, true, false, false]> : vector<4xi1>
245+
// CHECK-DAG: %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
246+
// CHECK-DAG: %[[VAL_1:.*]] = arith.constant dense<[true, true, false, false]> : vector<4xi1>
230247
// CHECK: %[[VAL_2:.*]] = vector.maskedload %[[ARG0]]{{\[}}%[[ARG1]]], %[[VAL_1]], %[[VAL_0]] : memref<?xf32, 1>, vector<4xi1>, vector<4xf32> into vector<4xf32>
231248
// CHECK: vector.maskedstore %[[ARG0]]{{\[}}%[[ARG1]]], %[[VAL_1]], %[[VAL_2]] : memref<?xf32, 1>, vector<4xi1>, vector<4xf32>
232249
// CHECK: return
@@ -243,10 +260,10 @@ func.func @masked_load_store(%arg0: memref<?xf32, 1>, %arg1: index) {
243260
// CHECK-LABEL: func.func @gather_scatter(
244261
// CHECK-SAME: %[[ARG0:.*]]: memref<?xf32, 1>,
245262
// CHECK-SAME: %[[ARG1:.*]]: index) {
246-
// CHECK: %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
247-
// CHECK: %[[VAL_1:.*]] = arith.constant dense<true> : vector<4xi1>
248-
// CHECK: %[[VAL_2:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
249-
// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index
263+
// CHECK-DAG: %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
264+
// CHECK-DAG: %[[VAL_1:.*]] = arith.constant dense<true> : vector<4xi1>
265+
// CHECK-DAG: %[[VAL_2:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
266+
// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : index
250267
// CHECK: %[[VAL_4:.*]] = vector.gather %[[ARG0]]{{\[}}%[[VAL_3]]] {{\[}}%[[VAL_2]]], %[[VAL_1]], %[[VAL_0]] : memref<?xf32, 1>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>
251268
// CHECK: vector.scatter %[[ARG0]]{{\[}}%[[VAL_3]]] {{\[}}%[[VAL_2]]], %[[VAL_1]], %[[VAL_4]] : memref<?xf32, 1>, vector<4xindex>, vector<4xi1>, vector<4xf32>
252269
// CHECK: return
@@ -265,8 +282,8 @@ func.func @gather_scatter(%arg0: memref<?xf32, 1>, %arg1: index) {
265282
// CHECK-LABEL: func.func @expandload_compressstore(
266283
// CHECK-SAME: %[[ARG0:.*]]: memref<?xf32, 1>,
267284
// CHECK-SAME: %[[ARG1:.*]]: index) {
268-
// CHECK: %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
269-
// CHECK: %[[VAL_1:.*]] = arith.constant dense<[true, true, false, false]> : vector<4xi1>
285+
// CHECK-DAG: %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
286+
// CHECK-DAG: %[[VAL_1:.*]] = arith.constant dense<[true, true, false, false]> : vector<4xi1>
270287
// CHECK: %[[VAL_2:.*]] = vector.expandload %[[ARG0]]{{\[}}%[[ARG1]]], %[[VAL_1]], %[[VAL_0]] : memref<?xf32, 1>, vector<4xi1>, vector<4xf32> into vector<4xf32>
271288
// CHECK: vector.compressstore %[[ARG0]]{{\[}}%[[ARG1]]], %[[VAL_1]], %[[VAL_2]] : memref<?xf32, 1>, vector<4xi1>, vector<4xf32>
272289
// CHECK: return

0 commit comments

Comments (0)