
Commit 6a9eb1d

[mlir][AMDGPU] Improve masked_load(..., broadcast(...), ...) handling

1. Fix the fact that the full masked load pattern (which creates an if statement) could overlap with the buffer load handling pattern, since the two didn't have distinct pattern benefits and were relying on their order of addition to the pattern set for priority (which isn't reliable).

While I was here, add more cases to the broadcast value recognizer:
- Since this pattern often runs after broadcast lowering, recognize splat vectors.
- Recognize broadcasts of unit vectors and convert them to the scalar case by constructing an extract().
- Look through shape_cast ops.
Parent: a858c90
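As a rough before/after sketch of the new unit-vector and shape_cast handling (mirroring the test cases added below; the value names %pred, %mem, %i, and %passthru are illustrative, and the exact expected output is whatever the CHECK lines in the new tests assert):

  // Before: the mask is a broadcast of a 1x1 predicate vector, reshaped to the load width.
  %mask = vector.broadcast %pred : vector<1x1xi1> to vector<2x2xi1>
  %flat = vector.shape_cast %mask : vector<2x2xi1> to vector<4xi1>
  %v = vector.maskedload %mem[%i, %i], %flat, %passthru
      : memref<8x8xf16>, vector<4xi1>, vector<4xf16> into vector<4xf16>

  // After: the scalar condition is extracted and guards an unmasked load.
  %cond = vector.extract %pred[0, 0] : i1 from vector<1x1xi1>
  %v = scf.if %cond -> (vector<4xf16>) {
    %load = vector.load %mem[%i, %i] : memref<8x8xf16>, vector<4xf16>
    scf.yield %load : vector<4xf16>
  } else {
    scf.yield %passthru : vector<4xf16>
  }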

2 files changed (+98, -14 lines)

mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp

Lines changed: 38 additions & 14 deletions
@@ -33,19 +33,18 @@ using namespace mlir::amdgpu;
 
 /// This pattern supports lowering of: `vector.maskedload` to `vector.load`
 /// and `arith.select` if the memref is in buffer address space.
-static LogicalResult baseInBufferAddrSpace(PatternRewriter &rewriter,
-                                           vector::MaskedLoadOp maskedOp) {
-  auto memRefType = dyn_cast<MemRefType>(maskedOp.getBase().getType());
+static LogicalResult hasBufferAddressSpace(Type type) {
+  auto memRefType = dyn_cast<MemRefType>(type);
   if (!memRefType)
-    return rewriter.notifyMatchFailure(maskedOp, "not a memref source");
+    return failure();
 
   Attribute addrSpace = memRefType.getMemorySpace();
   if (!isa_and_nonnull<amdgpu::AddressSpaceAttr>(addrSpace))
-    return rewriter.notifyMatchFailure(maskedOp, "no address space");
+    return failure();
 
   if (dyn_cast<amdgpu::AddressSpaceAttr>(addrSpace).getValue() !=
       amdgpu::AddressSpace::FatRawBuffer)
-    return rewriter.notifyMatchFailure(maskedOp, "not in buffer address space");
+    return failure();
 
   return success();
 }
@@ -62,13 +61,25 @@ static Value createVectorLoadForMaskedLoad(OpBuilder &builder, Location loc,
   return load;
 }
 
-/// Check if the given value comes from a broadcasted i1 condition.
-static FailureOr<Value> matchFullMask(OpBuilder &b, Value val) {
+/// If the given value is the broadcast of a non-constant scalar, return that
+/// scalar, extracting it from length-1 vectors if necessary.
+static FailureOr<Value> getFullMask(RewriterBase &rw, Value val) {
+  while (auto shapeCast = val.getDefiningOp<vector::ShapeCastOp>())
+    val = shapeCast.getSource();
+  auto splatOp = val.getDefiningOp<vector::SplatOp>();
+  if (splatOp)
+    return splatOp.getInput();
   auto broadcastOp = val.getDefiningOp<vector::BroadcastOp>();
   if (!broadcastOp)
     return failure();
-  if (isa<VectorType>(broadcastOp.getSourceType()))
-    return failure();
+  if (auto sourceVecType = dyn_cast<VectorType>(broadcastOp.getSourceType())) {
+    if (sourceVecType.isScalable() || sourceVecType.getNumElements() != 1)
+      return failure();
+    SmallVector<int64_t> indices(sourceVecType.getRank(), 0);
+    Value scalarSource = vector::ExtractOp::create(
+        rw, broadcastOp.getLoc(), broadcastOp.getSource(), indices);
+    return scalarSource;
+  }
   return broadcastOp.getSource();
 }
 
@@ -85,14 +96,14 @@ struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
     if (maskedOp->hasAttr(kMaskedloadNeedsMask))
       return failure();
 
-    if (failed(baseInBufferAddrSpace(rewriter, maskedOp))) {
+    if (failed(hasBufferAddressSpace(maskedOp.getBase().getType()))) {
       return failure();
     }
 
     // Check if this is either a full inbounds load or an empty, oob load. If
     // so, take the fast path and don't generate an if condition, because we
     // know doing the oob load is always safe.
-    if (succeeded(matchFullMask(rewriter, maskedOp.getMask()))) {
+    if (succeeded(getFullMask(rewriter, maskedOp.getMask()))) {
       Value load = createVectorLoadForMaskedLoad(rewriter, maskedOp.getLoc(),
                                                  maskedOp, /*passthru=*/true);
       rewriter.replaceOp(maskedOp, load);
@@ -176,7 +187,11 @@ struct FullMaskedLoadToConditionalLoad
 
   LogicalResult matchAndRewrite(vector::MaskedLoadOp loadOp,
                                 PatternRewriter &rewriter) const override {
-    FailureOr<Value> maybeCond = matchFullMask(rewriter, loadOp.getMask());
+    if (succeeded(hasBufferAddressSpace(loadOp.getBase().getType())))
+      return rewriter.notifyMatchFailure(
+          loadOp, "buffer loads are handled by a more specialized pattern");
+
+    FailureOr<Value> maybeCond = getFullMask(rewriter, loadOp.getMask());
     if (failed(maybeCond)) {
       return failure();
     }
@@ -203,7 +218,16 @@ struct FullMaskedStoreToConditionalStore
 
   LogicalResult matchAndRewrite(vector::MaskedStoreOp storeOp,
                                 PatternRewriter &rewriter) const override {
-    FailureOr<Value> maybeCond = matchFullMask(rewriter, storeOp.getMask());
+    // A condition-free implementation of fully masked stores requires
+    // 1) an accessor for the num_records field on buffer resources/fat pointers
+    // 2) knowledge that said field will always be set accurately - that is,
+    //    that writes to x < num_records of offset wouldn't trap, which is
+    //    something a pattern user would need to assert or we'd need to prove.
+    //
+    // Therefore, conditional stores to buffers still go down this path at
+    // present.
+
+    FailureOr<Value> maybeCond = getFullMask(rewriter, storeOp.getMask());
     if (failed(maybeCond)) {
       return failure();
     }

mlir/test/Dialect/AMDGPU/maskedload-to-load.mlir

Lines changed: 60 additions & 0 deletions
@@ -167,3 +167,63 @@ func.func @full_mask_maskedstore_to_store(%arg0: memref<8x8xf16>, %arg1: index,
 // CHECK-NOT: vector.maskedstore
 // CHECK: scf.if %[[PRED]]
 // CHECK: vector.store
+
+// -----
+
+// CHECK-LABEL: func.func @full_select_maskedload_to_load_splat
+// CHECK-SAME: %[[MEM:.+]]: memref<8x8xf16>,
+// CHECK-SAME: %[[IDX:.+]]: index,
+// CHECK-SAME: %[[PRED:.+]]: i1,
+// CHECK-SAME: %[[PASSTHRU:.+]]: vector<4xf16>)
+func.func @full_select_maskedload_to_load_splat(%arg0: memref<8x8xf16>, %arg1: index, %arg2: i1, %arg3: vector<4xf16>) -> vector<4xf16> {
+  %0 = vector.splat %arg2 : vector<4xi1>
+  %1 = vector.maskedload %arg0[%arg1, %arg1], %0, %arg3 : memref<8x8xf16>, vector<4xi1>, vector<4xf16> into vector<4xf16>
+  return %1 : vector<4xf16>
+}
+// CHECK-NOT: vector.maskedload
+// CHECK: scf.if %[[PRED]]
+// CHECK: %[[LOAD:.+]] = vector.load
+// CHECK: scf.yield %[[LOAD]]
+// CHECK: else
+// CHECK: scf.yield %[[PASSTHRU]]
+
+// -----
+
+// CHECK-LABEL: func.func @full_select_maskedload_to_load_unit_vector_pred
+// CHECK-SAME: %[[MEM:.+]]: memref<8x8xf16>,
+// CHECK-SAME: %[[IDX:.+]]: index,
+// CHECK-SAME: %[[PREDVEC:.+]]: vector<1xi1>,
+// CHECK-SAME: %[[PASSTHRU:.+]]: vector<4xf16>)
+func.func @full_select_maskedload_to_load_unit_vector_pred(%arg0: memref<8x8xf16>, %arg1: index, %arg2: vector<1xi1>, %arg3: vector<4xf16>) -> vector<4xf16> {
+  %0 = vector.broadcast %arg2 : vector<1xi1> to vector<4xi1>
+  %1 = vector.maskedload %arg0[%arg1, %arg1], %0, %arg3 : memref<8x8xf16>, vector<4xi1>, vector<4xf16> into vector<4xf16>
+  return %1 : vector<4xf16>
+}
+// CHECK-NOT: vector.maskedload
+// CHECK: %[[PRED:.+]] = vector.extract %[[PREDVEC]][0] : i1 from vector<1xi1>
+// CHECK: scf.if %[[PRED]]
+// CHECK: %[[LOAD:.+]] = vector.load
+// CHECK: scf.yield %[[LOAD]]
+// CHECK: else
+// CHECK: scf.yield %[[PASSTHRU]]
+
+// -----
+
+// CHECK-LABEL: func.func @full_select_maskedload_to_load_2d_unit_vector_pred
+// CHECK-SAME: %[[MEM:.+]]: memref<8x8xf16>,
+// CHECK-SAME: %[[IDX:.+]]: index,
+// CHECK-SAME: %[[PREDVEC:.+]]: vector<1x1xi1>,
+// CHECK-SAME: %[[PASSTHRU:.+]]: vector<4xf16>)
+func.func @full_select_maskedload_to_load_2d_unit_vector_pred(%arg0: memref<8x8xf16>, %arg1: index, %arg2: vector<1x1xi1>, %arg3: vector<4xf16>) -> vector<4xf16> {
+  %0 = vector.broadcast %arg2 : vector<1x1xi1> to vector<2x2xi1>
+  %1 = vector.shape_cast %0 : vector<2x2xi1> to vector<4xi1>
+  %2 = vector.maskedload %arg0[%arg1, %arg1], %1, %arg3 : memref<8x8xf16>, vector<4xi1>, vector<4xf16> into vector<4xf16>
+  return %2 : vector<4xf16>
+}
+// CHECK-NOT: vector.maskedload
+// CHECK: %[[PRED:.+]] = vector.extract %[[PREDVEC]][0, 0] : i1 from vector<1x1xi1>
+// CHECK: scf.if %[[PRED]]
+// CHECK: %[[LOAD:.+]] = vector.load
+// CHECK: scf.yield %[[LOAD]]
+// CHECK: else
+// CHECK: scf.yield %[[PASSTHRU]]
