
Commit a8ac1fb

Fix GlobalLoad 4b lowering (#1764)
In some cases the GlobalLoad op will already have been replaced at this point, so we can't access it afterwards.
1 parent 9fdebc4 commit a8ac1fb
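
For readers unfamiliar with the hazard: once an MLIR PatternRewriter has replaced the matched op, that op is scheduled for erasure and must not be read again. Below is a minimal, self-contained C++ sketch of the copy-before-replace ordering the fix adopts. It is written against the upstream memref::LoadOp purely for illustration and is not code from rocMLIR.

// Illustrative only: demonstrates the copy-before-replace ordering, not a
// useful rewrite by itself (replacing a load with an identical load would
// simply re-match under a greedy driver).
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

struct CopyBeforeReplaceExample : OpRewritePattern<memref::LoadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::LoadOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    // Everything still needed later is copied out of `op` up front, because
    // the replaceOp() call below invalidates further access to `op`.
    MemRefType srcType = op.getMemRefType();
    Type resultType = op.getResult().getType();
    SmallVector<Value> indices(op.getIndices());

    // Stand-in for the real lowering: re-emit a load from the copied state.
    Value lowered =
        rewriter.create<memref::LoadOp>(loc, op.getMemRef(), indices);
    rewriter.replaceOp(op, lowered); // `op` must not be touched after this.

    // Any post-processing works only with the copies made above.
    (void)srcType;
    (void)resultType;
    return success();
  }
};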

2 files changed: +37, −9 lines changed
mlir/lib/Dialect/Rock/Transforms/SugarToLoops.cpp

Lines changed: 17 additions & 9 deletions
@@ -1133,9 +1133,9 @@ std::tuple<SmallVector<Value>, Type> getCoordsAndType(PatternRewriter &b,
 
 // A helper to select the right i4 element if it was supposed to
 // be a scalar i4 load.
-Value selectDataIf4b(PatternRewriter &b, GlobalLoadOp op, Value loadedVec) {
-  MemRefType srcType = op.getSource().getType();
-  Type originalLoadedType = op.getResult().getType();
+Value selectDataIf4b(Location loc, PatternRewriter &b,
+                     SmallVector<Value> &coords, MemRefType srcType,
+                     Type originalLoadedType, Value loadedVec) {
   if (srcType.getElementType().getIntOrFloatBitWidth() >= 8) {
     return loadedVec;
   }
@@ -1150,8 +1150,6 @@ Value selectDataIf4b(PatternRewriter &b, GlobalLoadOp op, Value loadedVec) {
   assert(srcType.getElementType().getIntOrFloatBitWidth() == 4 &&
          "we only support 4bits in narrow types");
   assert(isa<VectorType>(loadedVec.getType()));
-  Location loc = op.getLoc();
-  SmallVector<Value, 5> coords(op.getSourceCoord());
   ArrayRef<int64_t> shape = srcType.getShape();
   Value flatAddress = flattenCoords(b, loc, coords, shape);
   Type coordType = flatAddress.getType();
@@ -1198,6 +1196,12 @@ struct GlobalLoadRewritePattern : public OpRewritePattern<GlobalLoadOp> {
       source = zeroDMemrefAsOneD(b, source);
       coords.push_back(b.createOrFold<ConstantIndexOp>(loc, 0));
     }
+    // We need to copy these params here, because the next if might replace
+    // "op". So, we can't safely access it after that.
+    // TODO: refactor this code
+    MemRefType srcType = op.getSource().getType();
+    Type originalLoadedType = op.getResult().getType();
+    SmallVector<Value> sourceCoords(op.getSourceCoord());
 
     PatternRewriter::InsertionGuard insertGuard(b);
     if (emitOobChecks && !useBufferOps) {
@@ -1207,11 +1211,12 @@ struct GlobalLoadRewritePattern : public OpRewritePattern<GlobalLoadOp> {
             loc, arith::CmpIPredicate::uge, coords[0], numElems);
         cond = b.create<arith::AndIOp>(loc, fallsOffEnd, cond);
       }
-      auto guard = b.create<scf::IfOp>(loc, loadedType, cond, true, true);
+      auto guard =
+          b.create<scf::IfOp>(loc, originalLoadedType, cond, true, true);
       b.replaceOp(op, guard);
 
       b.setInsertionPointToEnd(guard.getBody(1));
-      Value zeroes = createZeroConstantOp(b, loc, loadedType);
+      Value zeroes = createZeroConstantOp(b, loc, originalLoadedType);
       b.create<scf::YieldOp>(loc, zeroes);
       b.setInsertionPointToEnd(guard.getBody(0));
     }
@@ -1249,15 +1254,18 @@ struct GlobalLoadRewritePattern : public OpRewritePattern<GlobalLoadOp> {
         else
          loaded = thisLoad;
       });
-      loaded = selectDataIf4b(b, op, loaded);
+      loaded = selectDataIf4b(loc, b, sourceCoords, srcType, originalLoadedType,
+                              loaded);
       b.replaceOp(op, loaded);
     } else {
       Value loaded;
       if (isa<VectorType>(loadedType))
        loaded = b.create<vector::LoadOp>(loc, loadedType, source, coords);
       else
        loaded = b.create<memref::LoadOp>(loc, loadedType, source, coords);
-      loaded = selectDataIf4b(b, op, loaded);
+
+      loaded = selectDataIf4b(loc, b, sourceCoords, srcType, originalLoadedType,
+                              loaded);
       if (emitOobChecks)
        b.create<scf::YieldOp>(loc, loaded);
       else
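
For context on what the 4-bit path computes: a scalar i4 cannot be loaded directly, so the lowering loads a vector<2xi4> and picks the element selected by the low bit of the flattened source coordinate (the test below checks for exactly this vector.load / vector.extractelement pair). A rough sketch of that selection step follows; the helper name and the bit-masking logic are assumptions for illustration, not the actual body of selectDataIf4b.

// Assumed logic for the element-select step (not copied from
// SugarToLoops.cpp): the low bit of the flattened element index chooses
// which half of the loaded vector<2xi4> holds the requested scalar.
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

static Value selectI4FromPair(PatternRewriter &b, Location loc,
                              Value flatAddress, Value loadedVec) {
  Type idxType = flatAddress.getType();
  Value one =
      b.create<arith::ConstantOp>(loc, b.getIntegerAttr(idxType, 1));
  // flatAddress & 1 picks lane 0 or lane 1 of the 2xi4 vector.
  Value lane = b.create<arith::AndIOp>(loc, flatAddress, one);
  return b.create<vector::ExtractElementOp>(loc, loadedVec, lane);
}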

mlir/test/Dialect/Rock/lowering_global_load_store.mlir

Lines changed: 20 additions & 0 deletions
@@ -111,6 +111,26 @@ func.func @load_scalar_oob_large(%mem: memref<1073741825xf32>, %valid: i1) -> f32 {
   return %ret : f32
 }
 
+// CHECK-LABEL: func.func @load_scalar_oob_large_i4
+// CHECK-SAME: (%[[mem:.*]]: memref<1073741825xi4>, %[[valid:.*]]: i1)
+func.func @load_scalar_oob_large_i4(%mem: memref<1073741825xi4>, %valid: i1) -> i4 {
+  %c0 = arith.constant 0 : index
+  // CHECK: %[[zero:.*]] = arith.constant 0 : i4
+  // CHECK: %[[cast:.*]] = memref.memory_space_cast %[[mem]]
+  // CHECK-SAME: #gpu.address_space<global>
+  // CHECK: %[[ret:.*]] = scf.if %[[valid]] -> (i4)
+  // CHECK: %[[load:.*]] = vector.load %[[cast]]
+  // CHECK-SAME: memref<1073741825xi4, #gpu.address_space<global>>, vector<2xi4>
+  // CHECK: %[[element:.*]] = vector.extractelement %[[load]]
+  // CHECK: scf.yield %[[element]]
+  // CHECK: } else {
+  // CHECK: scf.yield %[[zero]] : i4
+  %ret = rock.global_load %mem[%c0] if %valid {needs64BitIdx}
+    : memref<1073741825xi4> -> i4
+  // CHECK: return %[[ret]]
+  return %ret : i4
+}
+
 // CHECK-LABEL: func.func @store_scalar_in_bounds
 // CHECK-SAME: (%[[source:.*]]: memref<5xf32, #gpu.address_space<private>>, %[[mem:.*]]: memref<1x2x3x4x8xf32>)
 func.func @store_scalar_in_bounds(%source: memref<5xf32, #gpu.address_space<private>>, %mem: memref<1x2x3x4x8xf32>) {
