Commit 98fe647

Add check for out-of-bound on loads

Signed-off-by: Ettore Tiotto <[email protected]>
1 parent 3b0022c commit 98fe647

2 files changed: +159 -40 lines changed
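In outline: the FuseReshape pass folds a tt.load -> tt.reshape pair by collapsing the two outermost dimensions of a 3-D block pointer into one. This commit adds the missing out-of-bound handling for that collapse: when the original load boundary-checked the collapsed dimension, the rewritten load is wrapped in an scf.if that tests the collapsed offset and yields the load's padding value (zero by default) when the access would be out of bounds. A minimal before/after sketch of the guarded load (illustrative shapes and names, not taken verbatim from the tests):

  // Before: 3-D block pointer, reshape after a boundary-checked load.
  %p3 = tt.make_tensor_ptr %base, [%d0, %d1, %d2], [%s0, %s1, %s2],
        [%o0, %o1, %o2] {order = array<i32: 2, 1, 0>} : <tensor<1x32x64xf16>>
  %v3 = tt.load %p3 {boundaryCheck = array<i32: 1, 2>} : !tt.ptr<tensor<1x32x64xf16>>
  %v = tt.reshape %v3 : tensor<1x32x64xf16> -> tensor<32x64xf16>

  // After: 2-D block pointer; the check on the collapsed dimension becomes
  // an explicit scf.if guard around the load.
  %p2 = tt.make_tensor_ptr %base, [%d01, %d2], [%s1, %s2], [%o01, %o2]
        {order = array<i32: 1, 0>} : <tensor<32x64xf16>>
  %cond = arith.cmpi ult, %o01, %bound : i32
  %v = scf.if %cond -> (tensor<32x64xf16>) {
    %l = tt.load %p2 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x64xf16>>
    scf.yield %l : tensor<32x64xf16>
  } else {
    %pad = arith.constant dense<0.000000e+00> : tensor<32x64xf16>
    scf.yield %pad : tensor<32x64xf16>
  }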

test/Triton/Intel/FuseReshape/fuse-reshape.mlir

Lines changed: 28 additions & 12 deletions
@@ -25,10 +25,17 @@ tt.func public @fuseLoadWithReshape1(%arg0: !tt.ptr<tensor<256x32xbf16>>, %arg1:
 // CHECK: [[TRUNC:%.*]] = arith.trunci [[DIV]] : i64 to i32
 // CHECK: [[MUL2:%.*]] = arith.muli %c2_i32, [[TRUNC]] : i32
 // CHECK: [[ADD2:%.*]] = arith.addi [[MUL2]], %c1_i32 : i32
-
 // CHECK: [[PTR:%.*]] = tt.make_tensor_ptr %arg1, [[[ADD1]], %c1024_i64], [%c4_i64, %c1_i64], [[[ADD2]], %c0_i32] {order = array<i32: 1, 0>} : <tensor<32x256xbf16>>
-// CHECK: [[LOAD_B:%.*]] = tt.load [[PTR]] {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x256xbf16>>
-// CHECK: tt.dot {{.*}}, [[LOAD_B]], {{.*}}, inputPrecision = tf32 : tensor<256x32xbf16> * tensor<32x256xbf16> -> tensor<256x256xf32>
+// CHECK: [[TRUNC:%.*]] = arith.trunci %c1_i64 : i64 to i32
+// CHECK: [[COND:%.*]] = arith.cmpi ult, [[ADD2]], [[TRUNC]] : i32
+// CHECK: [[IF_RES:%.*]] = scf.if [[COND]] -> (tensor<32x256xbf16>) {
+// CHECK: [[LOAD_B:%.*]] = tt.load [[PTR]] {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x256xbf16>>
+// CHECK: scf.yield [[LOAD_B]] : tensor<32x256xbf16>
+// CHECK: } else {
+// CHECK: [[ZERO:%.*]] = arith.constant dense<0.000000e+00> : tensor<32x256xbf16>
+// CHECK: scf.yield [[ZERO]] : tensor<32x256xbf16>
+// CHECK: }
+// CHECK: tt.dot {{.*}}, [[IF_RES]], {{.*}}, inputPrecision = tf32 : tensor<256x32xbf16> * tensor<32x256xbf16> -> tensor<256x256xf32>
 
 // -----
 
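These CHECK lines exercise case (b) of the new logic in FuseReshape.cpp below: the fused load still boundary-checks the collapsed outermost dimension, so the pass is expected to materialize the cmpi/scf.if guard with a zero-padding else branch rather than rely on the load's own boundaryCheck.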
@@ -46,7 +53,7 @@ tt.func public @fuseLoadWithReshape2(%arg0: !tt.ptr<tensor<32x256xbf16>>, %arg1:
   %0 = tt.make_tensor_ptr %arg1, [%c512_i64, %c1024_i64, %c32_i64], [%c1024_i64, %c1_i64, %c512_i64], [%c32_i32, %c32_i32, %c0_i32] {order = array<i32: 2, 0, 1>} : <tensor<1x256x32xbf16>>
   %res:2 = scf.for %arg3 = %c0_i32 to %c1024_i32 step %c32_i32 iter_args(%arg4 = %cst, %arg5 = %c0_i32) -> (tensor<256x256xf32>, i32) : i32 {
     %1 = tt.load %arg0 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x256xbf16>>
-    %3 = tt.load %0 {boundaryCheck = array<i32: 2, 1>} : !tt.ptr<tensor<1x256x32xbf16>>
+    %3 = tt.load %0 {boundaryCheck = array<i32: 1>} : !tt.ptr<tensor<1x256x32xbf16>>
     %2 = tt.reshape %3 : tensor<1x256x32xbf16> -> tensor<256x32xbf16>
     %4 = tt.dot %2, %1, %arg4, inputPrecision = tf32 : tensor<256x32xbf16> * tensor<32x256xbf16> -> tensor<256x256xf32>
     %5 = arith.addi %arg5, %c32_i32 : i32
@@ -64,7 +71,7 @@ tt.func public @fuseLoadWithReshape2(%arg0: !tt.ptr<tensor<32x256xbf16>>, %arg1:
 // CHECK: [[ADD2:%.*]] = arith.addi [[MUL2]], %c0_i32 : i32
 // CHECK: [[PTR:%.*]] = tt.make_tensor_ptr %arg1, [%c1024_i64, [[ADD1]]], [%c1_i64, %c512_i64], [%c32_i32, [[ADD2]]] {order = array<i32: 0, 1>} : <tensor<256x32xbf16>>
 // CHECK: scf.for
-// CHECK: [[LOAD_A:%.*]] = tt.load [[PTR]] {boundaryCheck = array<i32: 1, 0>} : !tt.ptr<tensor<256x32xbf16>>
+// CHECK: [[LOAD_A:%.*]] = tt.load [[PTR]] {boundaryCheck = array<i32: 0>} : !tt.ptr<tensor<256x32xbf16>>
 // CHECK: tt.dot [[LOAD_A]], {{.*}}, {{.*}}, inputPrecision = tf32 : tensor<256x32xbf16> * tensor<32x256xbf16> -> tensor<256x256xf32>
 
 // -----
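Here the surviving check does not land on the collapsed dimension of the new pointer (which, given the order attribute, sits at index 1), so no scf.if guard is expected; the checked index is simply decremented (boundaryCheck[i] - 1 in the pass) to account for the dropped dimension.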
@@ -121,17 +128,26 @@ tt.func public @fuseLoadWithReshape3(%a_ptr: !tt.ptr<f32> {tt.divisibility = 16
 }
 // CHECK-LABEL: fuseLoadWithReshape3
 // CHECK-NOT: tt.reshape
+// CHECK: [[EXT_M:%.*]] = arith.extsi %arg3 : i32 to i64
 // CHECK: [[DIV:%.*]] = arith.divui %c1_i64, %17 : i64
 // CHECK: [[MUL1:%.*]] = arith.muli %c1_i64, [[DIV]] : i64
 // CHECK: [[ADD1:%.*]] = arith.addi [[MUL1]], %15 : i64
 // CHECK: [[TRUNC:%.*]] = arith.trunci [[DIV]] : i64 to i32
 // CHECK: [[MUL2:%.*]] = arith.muli %c0_i32, [[TRUNC]] : i32
 // CHECK: [[ADD2:%.*]] = arith.addi [[MUL2]], %14 : i32
 // CHECK: [[PTR:%.*]] = tt.make_tensor_ptr %arg0, [[[ADD1]], %16], [%17, %c1_i64], [[[ADD2]], %c0_i32] {order = array<i32: 1, 0>} : <tensor<256x32xf32>>
+// CHECK: [[TRUNC:%.*]] = arith.trunci [[EXT_M]] : i64 to i32
+// CHECK: [[COND:%.*]] = arith.cmpi ult, [[ADD2]], [[TRUNC]] : i32
 // CHECK: scf.for {{.*}} = %c0_i32 to {{.*}} step %c32_i32 iter_args([[ARG:%.*]] = [[PTR]]
-// CHECK: [[LOAD_A:%.*]] = tt.load [[ARG]] {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x32xf32>>
-// CHECK: tt.dot [[LOAD_A]], {{.*}}, {{.*}}, inputPrecision = tf32 : tensor<256x32xf32> * tensor<32x128xf32> -> tensor<256x128xf32>
-// CHECK: tt.advance [[ARG]], [%c0_i32, %c32_i32] : <tensor<256x32xf32>>
+// CHECK: [[IF_RES:%.*]] = scf.if [[COND]] -> (tensor<256x32xf32>) {
+// CHECK: [[LOAD_A:%.*]] = tt.load [[ARG]] {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x32xf32>>
+// CHECK: scf.yield [[LOAD_A]] : tensor<256x32xf32>
+// CHECK: } else {
+// CHECK: [[ZERO:%.*]] = arith.constant dense<0.000000e+00> : tensor<256x32xf32>
+// CHECK: scf.yield [[ZERO]] : tensor<256x32xf32>
+// CHECK: }
+// CHECK: tt.dot [[IF_RES]], {{.*}}, {{.*}}, inputPrecision = tf32 : tensor<256x32xf32> * tensor<32x128xf32> -> tensor<256x128xf32>
+// CHECK: tt.advance [[ARG]], [%c0_i32, %c32_i32] : <tensor<256x32xf32>>
 
 // -----
 
@@ -152,15 +168,15 @@ tt.func public @fuseLoadWithTrans4(%arg0: i32, %arg1: !tt.ptr<f16>, %arg2: !tt.p
   %11 = tt.load %10 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x32xf16>>
   %res1:1 = scf.for %arg3 = %c0_i32 to %arg0 step %c32_i32 iter_args(%arg4 = %arg0) -> (i32) : i32 {
     %adv = tt.advance %9, [%arg4, %c0_i32] : <tensor<1x32x64xf16>>
-    %load = tt.load %adv {boundaryCheck = array<i32: 1, 2>} : !tt.ptr<tensor<1x32x64xf16>>
+    %load = tt.load %adv : !tt.ptr<tensor<1x32x64xf16>>
     %reshape = tt.reshape %load : tensor<1x32x64xf16> -> tensor<32x64xf16>
     %dot = tt.dot %11, %reshape, %cst, inputPrecision = tf32 : tensor<64x32xf16> * tensor<32x64xf16> -> tensor<64x64xf32>
     %add = arith.addi %arg4, %c32_i32 : i32
     scf.yield %add : i32
   }
   %res2:1 = scf.for %arg3 = %c0_i32 to %arg0 step %c32_i32 iter_args(%arg4 = %arg0) -> (i32) : i32 {
     %adv = tt.advance %9, [%arg4, %c0_i32] : <tensor<1x32x64xf16>>
-    %load = tt.load %adv {boundaryCheck = array<i32: 2, 1>} : !tt.ptr<tensor<1x32x64xf16>>
+    %load = tt.load %adv : !tt.ptr<tensor<1x32x64xf16>>
     %reshape = tt.reshape %load : tensor<1x32x64xf16> -> tensor<32x64xf16>
     %dot = tt.dot %11, %reshape, %cst, inputPrecision = tf32 : tensor<64x32xf16> * tensor<32x64xf16> -> tensor<64x64xf32>
     %add = arith.addi %arg4, %c32_i32 : i32
@@ -187,11 +203,11 @@ tt.func public @fuseLoadWithTrans4(%arg0: i32, %arg1: !tt.ptr<f16>, %arg2: !tt.p
 // CHECK: [[PTR2:%.*]] = tt.make_tensor_ptr %arg2, [[[ADD12]], %c64_i64], [%c64_i64, %c1_i64], [[[ADD22]], %c2_i32] {order = array<i32: 1, 0>} : <tensor<32x64xf16>>
 // CHECK: scf.for
 // CHECK: [[ADV:%.*]] = tt.advance [[PTR2]], {{.*}} : <tensor<32x64xf16>>
-// CHECK: [[LOAD_B1:%.*]] = tt.load [[ADV]] {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x64xf16>>
+// CHECK: [[LOAD_B1:%.*]] = tt.load [[ADV]] : !tt.ptr<tensor<32x64xf16>>
 // CHECK: tt.dot {{.*}}, [[LOAD_B1]], {{.*}}, inputPrecision = tf32 : tensor<64x32xf16> * tensor<32x64xf16> -> tensor<64x64xf32>
 // CHECK: scf.yield
 // CHECK: scf.for
 // CHECK: [[ADV:%.*]] = tt.advance [[PTR1]], {{.*}} : <tensor<32x64xf16>>
-// CHECK: [[LOAD_B1:%.*]] = tt.load [[ADV]] {boundaryCheck = array<i32: 1, 0>} : !tt.ptr<tensor<32x64xf16>>
+// CHECK: [[LOAD_B1:%.*]] = tt.load [[ADV]] : !tt.ptr<tensor<32x64xf16>>
 // CHECK: tt.dot {{.*}}, [[LOAD_B1]], {{.*}}, inputPrecision = tf32 : tensor<64x32xf16> * tensor<32x64xf16> -> tensor<64x64xf32>
 // CHECK: scf.yield
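Both loads in this test were changed to carry no boundaryCheck at all, which exercises the case 0 branch of the new switch in FuseReshape.cpp: there is nothing to materialize, so the fused loads pass through unguarded and the CHECK lines above expect plain tt.load operations.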

third_party/intel/lib/Dialect/Triton/Transforms/FuseReshape.cpp

Lines changed: 131 additions & 28 deletions
@@ -2,8 +2,10 @@
 #include "intel/include/Utils/DefUseChain.h"
 #include "intel/include/Utils/Utility.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Verifier.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
 #include "llvm/ADT/STLExtras.h"
@@ -22,6 +24,59 @@ namespace mlir::triton::intel {
 
 namespace {
 
+scf::IfOp createIfBlock(OpBuilder &builder, Location loc, arith::CmpIOp condOp,
+                        tt::LoadOp loadOp) {
+  assert(isa<RankedTensorType>(loadOp.getType()) &&
+         "Unexpected load result type");
+
+  auto tensorType = cast<RankedTensorType>(loadOp.getType());
+  assert(tensorType.getShape().size() == 2);
+  Type elemType = tensorType.getElementType();
+
+  builder.setInsertionPointAfter(loadOp);
+  auto ifOp = builder.create<scf::IfOp>(loc, tensorType, condOp, true, true);
+  loadOp->moveBefore(ifOp.thenBlock(), ifOp.thenBlock()->end());
+  builder.setInsertionPointAfter(loadOp);
+  builder.create<scf::YieldOp>(loc, loadOp->getResult(0));
+
+  builder.setInsertionPointToStart(ifOp.elseBlock());
+  tt::PaddingOption padding = (!loadOp.getPadding())
+                                  ? tt::PaddingOption::PAD_ZERO
+                                  : *loadOp.getPadding();
+  DenseElementsAttr denseAttr = nullptr;
+  switch (padding) {
+  case tt::PaddingOption::PAD_ZERO: {
+    denseAttr = DenseElementsAttr::get(cast<ShapedType>(tensorType),
+                                       builder.getZeroAttr(elemType));
+  } break;
+  case tt::PaddingOption::PAD_NAN: {
+    assert(isa<FloatType>(elemType) && "Expecting a floating point type");
+    auto NaNVal =
+        APFloat::getNaN(cast<FloatType>(elemType).getFloatSemantics());
+    denseAttr = DenseElementsAttr::get(cast<ShapedType>(tensorType),
+                                       builder.getFloatAttr(elemType, NaNVal));
+  } break;
+  default:
+    llvm_unreachable("Unhandled padding kind");
+  }
+  assert(denseAttr && "Expecting a valid attribute");
+
+  Value other = builder.create<arith::ConstantOp>(loc, tensorType, denseAttr);
+  builder.create<scf::YieldOp>(loc, other);
+  return ifOp;
+}
+
+scf::IfOp createCheckedLoad(OpBuilder &builder, arith::CmpIOp condOp,
+                            tt::LoadOp loadOp) {
+  scf::IfOp ifOp = createIfBlock(builder, loadOp.getLoc(), condOp, loadOp);
+  loadOp->replaceUsesWithIf(ifOp, [&](OpOperand &operand) {
+    if (auto yieldOp = dyn_cast<scf::YieldOp>(operand.getOwner()))
+      return yieldOp->getParentOp() != ifOp;
+    return true;
+  });
+  return ifOp;
+}
+
 // Transform:
 //   %one = arith.constant 1 : i64
 //   %ptr = make_tensor_ptr %q_view, [%q, %q_23, %q_24],
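createIfBlock clones the load's padding semantics into the else branch. For a load with padding = PAD_NAN on, say, tensor<32x64xf16>, the else branch would look roughly like this (illustrative sketch; 0x7E00 is the f16 quiet-NaN bit pattern):

  } else {
    %pad = arith.constant dense<0x7E00> : tensor<32x64xf16>
    scf.yield %pad : tensor<32x64xf16>
  }

With no padding attribute, PAD_ZERO is assumed and the constant is a zero splat, as the tests above check.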
@@ -87,7 +142,7 @@ class FuseReshape {
 
     LLVM_DEBUG(llvm::dbgs() << "[Before fusion]:\n" << manager << "\n");
 
-    // Fuse tt.LoadOp->tt.reshapeOp operations.
+    // Fuse tt.LoadOp->tt.ReshapeOp operations.
     fuse(manager.getChains());
 
     // Remove operations that are no longer used.
@@ -238,14 +293,57 @@ class FuseReshape {
 
     LLVM_DEBUG(llvm::dbgs() << "newMakeTensorPtrOp:\n " << ptr << "\n");
 
-    // Adjust the boundary check on the load operation.
-    ArrayRef<int> boundaryCheck = loadOp.getBoundaryCheck();
-    assert(boundaryCheck.size() == 2 && "Expecting a 2-dim load");
-    loadOp.setBoundaryCheck({boundaryCheck[0] - 1, boundaryCheck[1] - 1});
-
     // Propagate the new ptr through the def-use chain.
-    propagateToUsers(ptr, chain);
+    IRMapping mapping;
+    propagateToUsers(ptr, chain, mapping);
     cleanUp.insert(makeTensorPtrOp);
+
+    // We have collapsed 2 dimensions into one, therefore we might have to
+    // materialize the boundary check for the new collapsed dimension. There
+    // are 2 possibilities:
+    //  a) if the load checks only the innermost dimension, we are OK because
+    //     we haven't collapsed that dimension
+    //  b) if the load checks the new outermost dimension, the boundary check
+    //     on the load is not sufficient and we have to materialize the
+    //     correct boundary check. Example:
+    //                  OLD PTR           NEW PTR
+    //       shape:   [20, 10, 5]   ->   [210, 5]
+    //       strides: [50,  5, 1]   ->   [  5, 1]
+    //
+    //     Consider a load offset of [1, 11, 1]: this access is clearly
+    //     out-of-bound in dim 1 (11 > 10). However, the new offset is no
+    //     longer out-of-bound (21 < 210).
+    auto newLoadOp =
+        cast<tt::LoadOp>(mapping.lookup(static_cast<Operation *>(loadOp)));
+    ArrayRef<int> boundaryCheck = newLoadOp.getBoundaryCheck();
+    switch (boundaryCheck.size()) {
+    case 0:
+      break;
+    case 1:
+      // intentional fall-through
+    case 2: {
+      SmallVector<int> newBoundaryCheck{boundaryCheck[0] - 1};
+      if (boundaryCheck.size() == 2)
+        newBoundaryCheck.push_back(boundaryCheck[1] - 1);
+      newLoadOp.setBoundaryCheck({newBoundaryCheck});
+
+      if (llvm::any_of(newBoundaryCheck, [&](unsigned boundIdx) {
+            return boundIdx == newOutermostDimIdx;
+          })) {
+        Value lhs = newOffsets[newOutermostDimIdx];
+        Value rhs = shapes[newOutermostDimIdx + 1];
+        auto cmpOp = builder.create<arith::CmpIOp>(
+            loc, arith::CmpIPredicate::ult, lhs,
+            builder.create<arith::TruncIOp>(loc, lhs.getType(), rhs));
+        createCheckedLoad(builder, cmpOp, newLoadOp);
+      }
+    } break;
+    default:
+      // Note: while selecting candidates, we already ensured that the
+      // original load's boundary check doesn't check dim zero. So its max
+      // rank should be 2.
+      assert(boundaryCheck.size() != 3 && "Unexpected boundary check rank");
+    }
   }
 
   // Candidate is of the form:
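Making the comment's example concrete: the collapsed extent is 20 * (50/5) + 10 = 210 and the collapsed offset for [1, 11, 1] is 1 * (50/5) + 11 = 21, so the collapsed bound alone (21 < 210) misses the dim-1 overflow (11 > 10). The guard built above therefore compares the new outermost offset against the truncated original bound and routes the load through scf.if, roughly as follows (a sketch mirroring the fuseLoadWithReshape3 CHECK lines; names illustrative):

  %bound_i32 = arith.trunci %bound : i64 to i32
  %cond = arith.cmpi ult, %off, %bound_i32 : i32
  %res = scf.if %cond -> (tensor<256x32xf32>) {
    %ld = tt.load %ptr {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x32xf32>>
    scf.yield %ld : tensor<256x32xf32>
  } else {
    %zero = arith.constant dense<0.000000e+00> : tensor<256x32xf32>
    scf.yield %zero : tensor<256x32xf32>
  }

Here %off stands for newOffsets[newOutermostDimIdx] and %bound for shapes[newOutermostDimIdx + 1] in the code above.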
@@ -276,7 +374,8 @@ class FuseReshape {
       return false;
     }
 
-    // Check whether \p reshapeOp is used by a `dotOp` (directly or indirectly).
+    // Check whether \p reshapeOp is used by a `dotOp` (directly or
+    // indirectly).
     auto usedByDotOp = [](tt::ReshapeOp reshapeOp) {
       if (!reshapeOp->hasOneUse())
         return false;
@@ -347,10 +446,10 @@ class FuseReshape {
                     [](int val) { return val == 0; });
   }
 
-  // Prune chains that cannot be handled during fusion. For example, operations
-  // in the def-use chain should have a single user, except in special
-  // circumstances (e.g. the root operation of a chain might have more than one
-  // user).
+  // Prune chains that cannot be handled during fusion. For example,
+  // operations in the def-use chain should have a single user, except in
+  // special circumstances (e.g. the root operation of a chain might have more
+  // than one user).
   void pruneInvalid(DefUseChains &chains) const {
     assert(!chains.empty() && "Expecting at least one candidate chain");
 
@@ -368,11 +467,10 @@ class FuseReshape {
 
   // Determine whether all operations in the given def-use chain have a single
   // user.
-  // Note: we allow an operation in the def-use chain to have an additional user
-  // if the operation is in a for loop, and the additional user is the loop
-  // yield operation, provided that the result yielded is not used after the
-  // loop.
-  // Example:
+  // Note: we allow an operation in the def-use chain to have an additional
+  // user if the operation is in a for loop, and the additional user is the
+  // loop yield operation, provided that the result yielded is not used after
+  // the loop. Example:
   //   make_tensor_ptr -> advance -> load (OK)
   //   make_tensor_ptr -> for init_arg -> advance -> load (OK)
   //                                   -> yield (OK)
@@ -461,7 +559,8 @@ class FuseReshape {
   }
 
   // Propagate \p newVal to operations in the given def-use chain.
-  void propagateToUsers(Value newVal, const DefUseChain &chain) {
+  void propagateToUsers(Value newVal, const DefUseChain &chain,
+                        IRMapping &mapping) {
     auto start = cast<tt::MakeTensorPtrOp>(chain.getStart());
     Operation *end = chain.getEnd();
     auto it = llvm::find_if(start->getUsers(), [&](Operation *user) {
@@ -470,22 +569,22 @@ class FuseReshape {
     assert(it != start->getUsers().end() && "Expecting valid iterator");
 
     Operation *nextOp = *it;
-    propagateToUser(newVal, start.getResult(), nextOp, end);
+    propagateToUser(newVal, start.getResult(), nextOp, end, mapping);
   }
 
   // Propagate \p newVal to users of \p origOp.
   void propagateToUsers(Value newVal, Value origVal, Operation *origOp,
-                        Operation *sentinel) {
+                        Operation *sentinel, IRMapping &mapping) {
     assert(origOp && sentinel && "Expecting valid operations");
     const SmallVector<Operation *> users(origOp->getUsers());
     for (Operation *user : users)
-      propagateToUser(newVal, origVal, user, sentinel);
+      propagateToUser(newVal, origVal, user, sentinel, mapping);
   }
 
   // If \p user is not \p sentinel, propagate \p newVal to \p user. Otherwise
   // terminate the propagation.
   void propagateToUser(Value newVal, Value origVal, Operation *user,
-                       Operation *sentinel) {
+                       Operation *sentinel, IRMapping &mapping) {
     assert(user && sentinel && "Expecting valid operations");
     assert(llvm::is_contained(origVal.getUsers(), user) && "Invalid usage");
 
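The extra IRMapping parameter threaded through these helpers records each original operation's replacement as the chain is rebuilt; fuse() later calls mapping.lookup(loadOp) to retrieve the rewritten tt.load so the boundary check can be adjusted and, if needed, the scf.if guard materialized around it.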
@@ -515,11 +614,13 @@ class FuseReshape {
       SmallVector<Value> newOffsets(advanceOp.getOffsets().drop_front());
       auto newAdvanceOp = rewriter.create<tt::AdvanceOp>(loc, newVal.getType(),
                                                          newVal, newOffsets);
+      mapping.map(static_cast<Operation *>(advanceOp),
+                  static_cast<Operation *>(newAdvanceOp));
       LLVM_DEBUG(llvm::dbgs().indent(2)
                  << "newAdvanceOp: " << newAdvanceOp << "\n");
       cleanUp.insert(advanceOp);
       return propagateToUsers(newAdvanceOp, advanceOp.getResult(), advanceOp,
-                              sentinel);
+                              sentinel, mapping);
     }
 
     if (auto loadOp = dyn_cast<tt::LoadOp>(user)) {
@@ -529,10 +630,12 @@ class FuseReshape {
           loadOp.getBoundaryCheckAttr(), loadOp.getPaddingAttr(),
           loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile());
       newLoadOp->setAttrs(loadOp->getAttrs());
-
+      mapping.map(static_cast<Operation *>(loadOp),
+                  static_cast<Operation *>(newLoadOp));
       LLVM_DEBUG(llvm::dbgs().indent(2) << "newLoadOp: " << newLoadOp << "\n");
       cleanUp.insert(loadOp);
-      return propagateToUsers(newLoadOp, loadOp.getResult(), loadOp, sentinel);
+      return propagateToUsers(newLoadOp, loadOp.getResult(), loadOp, sentinel,
+                              mapping);
     }
 
     if (auto yieldOp = dyn_cast<scf::YieldOp>(user)) {
@@ -553,13 +656,13 @@ class FuseReshape {
     }
 
     if (auto forOp = dyn_cast<scf::ForOp>(user))
-      return propagateToLoop(newVal, origVal, forOp, sentinel);
+      return propagateToLoop(newVal, origVal, forOp, sentinel, mapping);
 
     llvm_unreachable("Unexpected kind of user");
   }
 
   void propagateToLoop(Value newVal, Value origVal, LoopLikeOpInterface loopOp,
-                       Operation *sentinel) {
+                       Operation *sentinel, IRMapping &mapping) {
     assert(sentinel && sentinel != loopOp && "Unexpected sentinel kind");
     LLVM_DEBUG({
       llvm::dbgs() << "In " << __func__ << "\n";
@@ -573,7 +676,7 @@ class FuseReshape {
       rgnInitArg.setType(initArg.get().getType());
       const SmallVector<Operation *> users(rgnInitArg.getUsers());
       for (Operation *user : users)
-        propagateToUser(rgnInitArg, rgnInitArg, user, sentinel);
+        propagateToUser(rgnInitArg, rgnInitArg, user, sentinel, mapping);
     }
   }
 }

0 commit comments