@@ -33,19 +33,18 @@ using namespace mlir::amdgpu;
 
 /// This pattern supports lowering of: `vector.maskedload` to `vector.load`
 /// and `arith.select` if the memref is in buffer address space.
-static LogicalResult baseInBufferAddrSpace(PatternRewriter &rewriter,
-                                           vector::MaskedLoadOp maskedOp) {
-  auto memRefType = dyn_cast<MemRefType>(maskedOp.getBase().getType());
+static LogicalResult hasBufferAddressSpace(Type type) {
+  auto memRefType = dyn_cast<MemRefType>(type);
   if (!memRefType)
-    return rewriter.notifyMatchFailure(maskedOp, "not a memref source");
+    return failure();
 
   Attribute addrSpace = memRefType.getMemorySpace();
   if (!isa_and_nonnull<amdgpu::AddressSpaceAttr>(addrSpace))
-    return rewriter.notifyMatchFailure(maskedOp, "no address space");
+    return failure();
 
   if (dyn_cast<amdgpu::AddressSpaceAttr>(addrSpace).getValue() !=
       amdgpu::AddressSpace::FatRawBuffer)
-    return rewriter.notifyMatchFailure(maskedOp, "not in buffer address space");
+    return failure();
 
   return success();
 }
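
For context, here is a minimal sketch of the IR shape this helper now guards. The value names, types, and the spelling of the address-space attribute are illustrative assumptions, not taken from the patch. A masked load whose base memref carries the fat raw buffer address space:

    %v = vector.maskedload %buf[%i], %mask, %passthru
           : memref<128xf32, #amdgpu.address_space<fat_raw_buffer>>,
             vector<4xi1>, vector<4xf32> into vector<4xf32>

is what MaskedLoadLowering targets and, per the doc comment above, rewrites into roughly a plain vector.load of the buffer followed by an arith.select against the mask:

    %full = vector.load %buf[%i]
              : memref<128xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
    %v = arith.select %mask, %full, %passthru : vector<4xi1>, vector<4xf32>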
@@ -83,10 +82,11 @@ struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
   LogicalResult matchAndRewrite(vector::MaskedLoadOp maskedOp,
                                 PatternRewriter &rewriter) const override {
     if (maskedOp->hasAttr(kMaskedloadNeedsMask))
-      return failure();
+      return rewriter.notifyMatchFailure(maskedOp, "already rewritten");
 
-    if (failed(baseInBufferAddrSpace(rewriter, maskedOp))) {
-      return failure();
+    if (failed(hasBufferAddressSpace(maskedOp.getBase().getType()))) {
+      return rewriter.notifyMatchFailure(
+          maskedOp, "isn't a load from a fat buffer resource");
     }
 
     // Check if this is either a full inbounds load or an empty, oob load. If
@@ -176,9 +176,14 @@ struct FullMaskedLoadToConditionalLoad
 
   LogicalResult matchAndRewrite(vector::MaskedLoadOp loadOp,
                                 PatternRewriter &rewriter) const override {
+    if (succeeded(hasBufferAddressSpace(loadOp.getBase().getType())))
+      return rewriter.notifyMatchFailure(
+          loadOp, "buffer loads are handled by a more specialized pattern");
+
     FailureOr<Value> maybeCond = matchFullMask(rewriter, loadOp.getMask());
     if (failed(maybeCond)) {
-      return failure();
+      return rewriter.notifyMatchFailure(loadOp,
+                                         "isn't loading a broadcasted scalar");
     }
 
     Value cond = maybeCond.value();
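
The new bailout above leaves buffer loads to MaskedLoadLowering and keeps this pattern for everything else. As a sketch of the rewrite FullMaskedLoadToConditionalLoad performs, assuming matchFullMask recovers the scalar i1 behind a broadcast mask (names and types here are illustrative):

    %mask = vector.broadcast %cond : i1 to vector<4xi1>
    %v = vector.maskedload %mem[%i], %mask, %passthru
           : memref<?xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>

becomes, roughly:

    %v = scf.if %cond -> (vector<4xf32>) {
      %loaded = vector.load %mem[%i] : memref<?xf32>, vector<4xf32>
      scf.yield %loaded : vector<4xf32>
    } else {
      scf.yield %passthru : vector<4xf32>
    }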
@@ -203,6 +208,15 @@ struct FullMaskedStoreToConditionalStore
 
   LogicalResult matchAndRewrite(vector::MaskedStoreOp storeOp,
                                 PatternRewriter &rewriter) const override {
+    // A condition-free implementation of fully masked stores requires
+    // 1) an accessor for the num_records field on buffer resources/fat
+    //    pointers, and
+    // 2) knowledge that said field is always set accurately - that is, that
+    //    writes at offsets x < num_records won't trap, which a pattern user
+    //    would need to assert or we'd need to prove.
+    //
+    // Therefore, masked stores to buffers still go down this path at present.
+
     FailureOr<Value> maybeCond = matchFullMask(rewriter, storeOp.getMask());
     if (failed(maybeCond)) {
       return failure();
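
In other words, a fully masked store such as the following (illustrative IR, assuming the same broadcast-mask matching as the load pattern):

    %mask = vector.broadcast %cond : i1 to vector<4xi1>
    vector.maskedstore %mem[%i], %mask, %val
        : memref<?xf32>, vector<4xi1>, vector<4xf32>

keeps being lowered to a guarded store rather than an unconditional one, roughly:

    scf.if %cond {
      vector.store %val, %mem[%i] : memref<?xf32>, vector<4xf32>
    }

even when %mem points at a buffer resource.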