Skip to content

Commit 6dbaa08

Browse files
committed
[mlir][linalg] Prevent hoisting of transfer pairs in the presence of aliases
This patch adds additional checks to the hoisting logic to prevent hoisting of `vector.transfer_read`/`vector.transfer_write` pairs when the underlying `memref` has users that introduce aliases via operations implementing `ViewLikeOpInterface`. Note: This may conservatively block some valid hoisting opportunities and could impact performance. However, as demonstrated by the included tests, the current behavior is too permissive and can lead to incorrect transformations. If this change prevents hoisting in cases that are provably safe, please share a minimal repro — I’d be happy to explore ways to relax the check.
1 parent 8b3e345 commit 6dbaa08

File tree

2 files changed

+154
-1
lines changed

2 files changed

+154
-1
lines changed

mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,7 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
303303
// 1. indices, vector type and permutation map are the same (i.e., the
304304
// transfer_read/transfer_write ops are matching),
305305
// 2. source operands for transfer.{read|write} do not originate from
306-
// Ops implementing ViewLikeOpInterface.
306+
// nor have users that are Ops implementing ViewLikeOpInterface.
307307
// 3. no other operations in the loop access the same memref except
308308
// for transfer_read/transfer_write accessing statically disjoint
309309
// slices.
@@ -312,14 +312,27 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
312312
transferRead.getPermutationMap() != transferWrite.getPermutationMap())
313313
return WalkResult::advance();
314314

315+
// Check 2. for xfer_read
315316
auto *source = transferRead.getBase().getDefiningOp();
316317
if (source && isa_and_nonnull<ViewLikeOpInterface>(source))
317318
return WalkResult::advance();
318319

320+
auto base = transferRead.getBase();
321+
for (auto *user : base.getUsers())
322+
if (isa_and_nonnull<ViewLikeOpInterface>(user))
323+
return WalkResult::advance();
324+
325+
// Check 2. for xfer_wrire
319326
source = transferWrite.getBase().getDefiningOp();
320327
if (source && isa_and_nonnull<ViewLikeOpInterface>(source))
321328
return WalkResult::advance();
322329

330+
base = transferWrite.getBase();
331+
for (auto *user : base.getUsers())
332+
if (isa_and_nonnull<ViewLikeOpInterface>(user))
333+
return WalkResult::advance();
334+
335+
// Check 1. + 3.
323336
// TODO: may want to memoize this information for performance but it
324337
// likely gets invalidated often.
325338
DominanceInfo dom(loop);

mlir/test/Dialect/Linalg/hoisting.mlir

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,145 @@
11
// RUN: mlir-opt -transform-interpreter -canonicalize --split-input-file --allow-unregistered-dialect %s | FileCheck %s
22

3+
///----------------------------------------------------------------------------------------
4+
/// Tests for vector.transfer_read + vector.transfer_write pairs
5+
///
6+
/// * Indices are static
7+
/// * Single loop
8+
///----------------------------------------------------------------------------------------
9+
10+
// The most basic example - hoisting is safe.
11+
12+
// CHECK-LABEL: func.func @hoist_basic_vector_xfer_pair(
13+
// CHECK-SAME: %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
14+
// CHECK-SAME: %[[LB:[a-zA-Z0-9]+]]: index,
15+
// CHECK-SAME: %[[UB:[a-zA-Z0-9]+]]: index,
16+
// CHECK-SAME: %[[STEP:[a-zA-Z0-9]+]]: index) {
17+
func.func @hoist_basic_vector_xfer_pair(
18+
%mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index) {
19+
%c0 = arith.constant 0 : index
20+
%pad = arith.constant 0.0 : f32
21+
22+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
23+
// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
24+
// CHECK: %[[READ:.*]] = vector.transfer_read %[[MEM]][%[[C0]], %[[C0]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
25+
// CHECK: %[[SCF:.*]] = scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[INIT:.*]] = %[[READ]]) -> (vector<1xf32>) {
26+
// CHECK: %[[VAL_6:.*]] = "some_use"(%[[INIT]]) : (vector<1xf32>) -> vector<1xf32>
27+
// CHECK: scf.yield %[[VAL_6]] : vector<1xf32>
28+
// CHECK: }
29+
// CHECK: vector.transfer_write %[[SCF]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
30+
scf.for %i = %lb to %ub step %step {
31+
%r0 = vector.transfer_read %mem[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
32+
%u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
33+
vector.transfer_write %u0, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
34+
}
35+
return
36+
}
37+
38+
module attributes {transform.with_named_sequence} {
39+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
40+
%0 = transform.structured.match ops{["func.func"]} in %arg1
41+
: (!transform.any_op) -> !transform.any_op
42+
transform.structured.hoist_redundant_vector_transfers %0
43+
: (!transform.any_op) -> !transform.any_op
44+
transform.yield
45+
}
46+
}
47+
48+
// -----
49+
50+
// Similar as the example above, but hoisting is no longer safe. That's due to
51+
// an extra xfer_write inside the loop.
52+
53+
// CHECK-LABEL: func.func @negative_hoist_basic_vector_xfer_pair_extra_write(
54+
// CHECK-SAME: %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
55+
// CHECK-SAME: %[[LB:[a-zA-Z0-9]+]]: index,
56+
// CHECK-SAME: %[[UB:[a-zA-Z0-9]+]]: index,
57+
// CHECK-SAME: %[[STEP:[a-zA-Z0-9]+]]: index,
58+
// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]: vector<1xf32>) {
59+
func.func @negative_hoist_basic_vector_xfer_pair_extra_write(
60+
%mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index, %in: vector<1xf32>) {
61+
%c0 = arith.constant 0 : index
62+
%pad = arith.constant 0.0 : f32
63+
64+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
65+
// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
66+
// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
67+
// CHECK: vector.transfer_write %[[IN]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
68+
// CHECK: %[[READ:.*]] = vector.transfer_read %[[MEM]][%[[C0]], %[[C0]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
69+
// CHECK: %[[USE:.*]] = "some_use"(%[[READ]]) : (vector<1xf32>) -> vector<1xf32>
70+
// CHECK: vector.transfer_write %[[USE]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
71+
// CHECK: }
72+
73+
scf.for %i = %lb to %ub step %step {
74+
vector.transfer_write %in, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
75+
76+
%r0 = vector.transfer_read %mem[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
77+
%u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
78+
vector.transfer_write %u0, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
79+
}
80+
return
81+
}
82+
83+
module attributes {transform.with_named_sequence} {
84+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
85+
%0 = transform.structured.match ops{["func.func"]} in %arg1
86+
: (!transform.any_op) -> !transform.any_op
87+
transform.structured.hoist_redundant_vector_transfers %0
88+
: (!transform.any_op) -> !transform.any_op
89+
transform.yield
90+
}
91+
}
92+
93+
// -----
94+
95+
// Similar as the example above, but hoisting is no longer safe. That's due to
96+
// an extra xfer_write into _an alias_ of the %mem Op that is used by the
97+
// original xfer pair.
98+
99+
// CHECK-LABEL: func.func @negative_hoist_basic_vector_xfer_pair_extra_write_into_alias(
100+
// CHECK-SAME: %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
101+
// CHECK-SAME: %[[LB:[a-zA-Z0-9]+]]: index,
102+
// CHECK-SAME: %[[UB:[a-zA-Z0-9]+]]: index,
103+
// CHECK-SAME: %[[STEP:[a-zA-Z0-9]+]]: index,
104+
// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]: vector<1xf32>) {
105+
func.func @negative_hoist_basic_vector_xfer_pair_extra_write_into_alias(
106+
%mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index, %in: vector<1xf32>) {
107+
%c0 = arith.constant 0 : index
108+
%pad = arith.constant 0.0 : f32
109+
110+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
111+
// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
112+
// CHECK: %[[SV:.*]] = memref.subview %[[MEM]][0, 0] [1, 1] [1, 1] : memref<?x?xf32> to memref<1x1xf32, strided<[?, 1]>>
113+
// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
114+
// CHECK: vector.transfer_write %[[IN]], %[[SV]][%[[C0]], %[[C0]]] {{.*}} : vector<1xf32>, memref<1x1xf32, strided<[?, 1]>>
115+
// CHECK: %[[READ:.*]] = vector.transfer_read %[[MEM]][%[[C0]], %[[C0]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
116+
// CHECK: %[[USE:.*]] = "some_use"(%[[READ]]) : (vector<1xf32>) -> vector<1xf32>
117+
// CHECK: vector.transfer_write %[[USE]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
118+
// CHECK: }
119+
120+
%sv = memref.subview %mem[0, 0][1, 1][1, 1] : memref<?x?xf32> to memref<1x1xf32, strided<[?, 1]>>
121+
scf.for %i = %lb to %ub step %step {
122+
vector.transfer_write %in, %sv[%c0, %c0] : vector<1xf32>, memref<1x1xf32, strided<[?, 1]>>
123+
124+
%r0 = vector.transfer_read %mem[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
125+
%u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
126+
vector.transfer_write %u0, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
127+
}
128+
return
129+
}
130+
131+
module attributes {transform.with_named_sequence} {
132+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
133+
%0 = transform.structured.match ops{["func.func"]} in %arg1
134+
: (!transform.any_op) -> !transform.any_op
135+
transform.structured.hoist_redundant_vector_transfers %0
136+
: (!transform.any_op) -> !transform.any_op
137+
transform.yield
138+
}
139+
}
140+
141+
// -----
142+
3143
///----------------------------------------------------------------------------------------
4144
/// Tests for vector.transfer_read + vector.transfer_write pairs
5145
///

0 commit comments

Comments
 (0)