@@ -33,19 +33,18 @@ using namespace mlir::amdgpu;
 
 /// This pattern supports lowering of `vector.maskedload` to `vector.load`
 /// and `arith.select` if the memref is in buffer address space.
-static LogicalResult baseInBufferAddrSpace(PatternRewriter &rewriter,
-                                           vector::MaskedLoadOp maskedOp) {
-  auto memRefType = dyn_cast<MemRefType>(maskedOp.getBase().getType());
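+/// Returns success if `type` is a memref whose memory space is the amdgpu
+/// fat raw buffer address space.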
+static LogicalResult hasBufferAddressSpace(Type type) {
+  auto memRefType = dyn_cast<MemRefType>(type);
   if (!memRefType)
-    return rewriter.notifyMatchFailure(maskedOp, "not a memref source");
+    return failure();
 
   Attribute addrSpace = memRefType.getMemorySpace();
   if (!isa_and_nonnull<amdgpu::AddressSpaceAttr>(addrSpace))
-    return rewriter.notifyMatchFailure(maskedOp, "no address space");
+    return failure();
 
   if (dyn_cast<amdgpu::AddressSpaceAttr>(addrSpace).getValue() !=
       amdgpu::AddressSpace::FatRawBuffer)
-    return rewriter.notifyMatchFailure(maskedOp, "not in buffer address space");
+    return failure();
 
   return success();
 }
@@ -62,13 +61,25 @@ static Value createVectorLoadForMaskedLoad(OpBuilder &builder, Location loc,
   return load;
 }
 
-/// Check if the given value comes from a broadcasted i1 condition.
-static FailureOr<Value> matchFullMask(OpBuilder &b, Value val) {
+/// If the given value is the broadcast of a non-constant scalar, return that
+/// scalar, extracting it from length-1 vectors if necessary.
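+/// For example, both `%m = vector.splat %c : vector<4xi1>` and
+/// `%m = vector.broadcast %v : vector<1xi1> to vector<4xi1>` are matched,
+/// yielding the underlying scalar condition.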
+static FailureOr<Value> getFullMask(RewriterBase &rw, Value val) {
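+  // Shape casts don't change the mask's values, only its shape, so look
+  // through them to find the mask's producer.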
+  while (auto shapeCast = val.getDefiningOp<vector::ShapeCastOp>())
+    val = shapeCast.getSource();
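+  // A splat is a broadcast of its scalar operand by definition.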
+  if (auto splatOp = val.getDefiningOp<vector::SplatOp>())
+    return splatOp.getInput();
   auto broadcastOp = val.getDefiningOp<vector::BroadcastOp>();
   if (!broadcastOp)
     return failure();
-  if (isa<VectorType>(broadcastOp.getSourceType()))
-    return failure();
+  if (auto sourceVecType = dyn_cast<VectorType>(broadcastOp.getSourceType())) {
+    if (sourceVecType.isScalable() || sourceVecType.getNumElements() != 1)
+      return failure();
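+    // Broadcasting a single-element vector is equivalent to splatting its
+    // only element; extract that element to return a scalar.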
+    SmallVector<int64_t> indices(sourceVecType.getRank(), 0);
+    Value scalarSource = vector::ExtractOp::create(
+        rw, broadcastOp.getLoc(), broadcastOp.getSource(), indices);
+    return scalarSource;
+  }
   return broadcastOp.getSource();
 }
 
@@ -85,14 +96,14 @@ struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
     if (maskedOp->hasAttr(kMaskedloadNeedsMask))
       return failure();
 
-    if (failed(baseInBufferAddrSpace(rewriter, maskedOp))) {
+    if (failed(hasBufferAddressSpace(maskedOp.getBase().getType()))) {
       return failure();
     }
 
     // Check if this is either a full inbounds load or an empty, oob load. If
     // so, take the fast path and don't generate an if condition, because we
     // know doing the oob load is always safe.
-    if (succeeded(matchFullMask(rewriter, maskedOp.getMask()))) {
+    if (succeeded(getFullMask(rewriter, maskedOp.getMask()))) {
       Value load = createVectorLoadForMaskedLoad(rewriter, maskedOp.getLoc(),
                                                  maskedOp, /*passthru=*/true);
       rewriter.replaceOp(maskedOp, load);
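
A minimal sketch of the fast path above (illustrative IR only; the memref and
vector types here are assumptions, not taken from the patch):

    %v = vector.maskedload %base[%i], %mask, %pass
        : memref<16xf32, #amdgpu.address_space<fat_raw_buffer>>,
          vector<4xi1>, vector<4xf32> into vector<4xf32>

With a full mask, this becomes an unconditional `vector.load` followed by an
`arith.select` between the loaded value and the passthrough, which is safe
because out-of-bounds reads from fat raw buffers are well-defined.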
@@ -176,7 +187,11 @@ struct FullMaskedLoadToConditionalLoad
 
   LogicalResult matchAndRewrite(vector::MaskedLoadOp loadOp,
                                 PatternRewriter &rewriter) const override {
-    FailureOr<Value> maybeCond = matchFullMask(rewriter, loadOp.getMask());
+    if (succeeded(hasBufferAddressSpace(loadOp.getBase().getType())))
+      return rewriter.notifyMatchFailure(
+          loadOp, "buffer loads are handled by a more specialized pattern");
+
+    FailureOr<Value> maybeCond = getFullMask(rewriter, loadOp.getMask());
     if (failed(maybeCond)) {
       return failure();
     }
@@ -203,7 +218,16 @@ struct FullMaskedStoreToConditionalStore
 
   LogicalResult matchAndRewrite(vector::MaskedStoreOp storeOp,
                                 PatternRewriter &rewriter) const override {
-    FailureOr<Value> maybeCond = matchFullMask(rewriter, storeOp.getMask());
+    // A condition-free implementation of fully masked stores would require
+    // 1) an accessor for the num_records field on buffer resources/fat
+    //    pointers, and
+    // 2) knowledge that said field is always set accurately - that is, that
+    //    writes at offsets x < num_records won't trap, which is something a
+    //    pattern user would need to assert or we'd need to prove.
+    //
+    // Therefore, conditional stores to buffers still go down this path at
+    // present.
+
+    FailureOr<Value> maybeCond = getFullMask(rewriter, storeOp.getMask());
     if (failed(maybeCond)) {
       return failure();
     }
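
For reference, a minimal sketch of the conditional store this path produces
(illustrative IR; %cond is the scalar recovered by getFullMask, and the types
are assumptions):

    scf.if %cond {
      vector.store %val, %base[%i] : memref<16xf32>, vector<4xf32>
    }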