@@ -1133,9 +1133,9 @@ std::tuple<SmallVector<Value>, Type> getCoordsAndType(PatternRewriter &b,
11331133
11341134// A helper to select the right i4 element if it was supposed to
11351135// be a scalar i4 load.
1136- Value selectDataIf4b (PatternRewriter &b, GlobalLoadOp op, Value loadedVec) {
1137- MemRefType srcType = op. getSource (). getType ();
1138- Type originalLoadedType = op. getResult (). getType ();
1136+ Value selectDataIf4b (Location loc, PatternRewriter &b,
1137+ SmallVector<Value> &coords, MemRefType srcType,
1138+ Type originalLoadedType, Value loadedVec) {
11391139 if (srcType.getElementType ().getIntOrFloatBitWidth () >= 8 ) {
11401140 return loadedVec;
11411141 }
@@ -1150,8 +1150,6 @@ Value selectDataIf4b(PatternRewriter &b, GlobalLoadOp op, Value loadedVec) {
11501150 assert (srcType.getElementType ().getIntOrFloatBitWidth () == 4 &&
11511151 " we only support 4bits in narrow types" );
11521152 assert (isa<VectorType>(loadedVec.getType ()));
1153- Location loc = op.getLoc ();
1154- SmallVector<Value, 5 > coords (op.getSourceCoord ());
11551153 ArrayRef<int64_t > shape = srcType.getShape ();
11561154 Value flatAddress = flattenCoords (b, loc, coords, shape);
11571155 Type coordType = flatAddress.getType ();
@@ -1198,6 +1196,12 @@ struct GlobalLoadRewritePattern : public OpRewritePattern<GlobalLoadOp> {
11981196 source = zeroDMemrefAsOneD (b, source);
11991197 coords.push_back (b.createOrFold <ConstantIndexOp>(loc, 0 ));
12001198 }
1199+ // We need to copy these params here, because the next if might replace
1200+ // "op". So, we can't safely access it after that.
1201+ // TODO: refactor this code
1202+ MemRefType srcType = op.getSource ().getType ();
1203+ Type originalLoadedType = op.getResult ().getType ();
1204+ SmallVector<Value> sourceCoords (op.getSourceCoord ());
12011205
12021206 PatternRewriter::InsertionGuard insertGuard (b);
12031207 if (emitOobChecks && !useBufferOps) {
@@ -1207,11 +1211,12 @@ struct GlobalLoadRewritePattern : public OpRewritePattern<GlobalLoadOp> {
12071211 loc, arith::CmpIPredicate::uge, coords[0 ], numElems);
12081212 cond = b.create <arith::AndIOp>(loc, fallsOffEnd, cond);
12091213 }
1210- auto guard = b.create <scf::IfOp>(loc, loadedType, cond, true , true );
1214+ auto guard =
1215+ b.create <scf::IfOp>(loc, originalLoadedType, cond, true , true );
12111216 b.replaceOp (op, guard);
12121217
12131218 b.setInsertionPointToEnd (guard.getBody (1 ));
1214- Value zeroes = createZeroConstantOp (b, loc, loadedType );
1219+ Value zeroes = createZeroConstantOp (b, loc, originalLoadedType );
12151220 b.create <scf::YieldOp>(loc, zeroes);
12161221 b.setInsertionPointToEnd (guard.getBody (0 ));
12171222 }
@@ -1249,15 +1254,18 @@ struct GlobalLoadRewritePattern : public OpRewritePattern<GlobalLoadOp> {
12491254 else
12501255 loaded = thisLoad;
12511256 });
1252- loaded = selectDataIf4b (b, op, loaded);
1257+ loaded = selectDataIf4b (loc, b, sourceCoords, srcType, originalLoadedType,
1258+ loaded);
12531259 b.replaceOp (op, loaded);
12541260 } else {
12551261 Value loaded;
12561262 if (isa<VectorType>(loadedType))
12571263 loaded = b.create <vector::LoadOp>(loc, loadedType, source, coords);
12581264 else
12591265 loaded = b.create <memref::LoadOp>(loc, loadedType, source, coords);
1260- loaded = selectDataIf4b (b, op, loaded);
1266+
1267+ loaded = selectDataIf4b (loc, b, sourceCoords, srcType, originalLoadedType,
1268+ loaded);
12611269 if (emitOobChecks)
12621270 b.create <scf::YieldOp>(loc, loaded);
12631271 else
0 commit comments