@@ -35,6 +35,26 @@ Value generateInBoundsCheck(OpBuilder &builder, Location loc, Value value,
3535 return inBounds;
3636}
3737
38+ // / Generate a runtime check to see if the given indices are in-bounds with
39+ // / respect to the given ranked memref.
40+ Value generateIndicesInBoundsCheck (OpBuilder &builder, Location loc,
41+ Value memref, ValueRange indices) {
42+ auto memrefType = cast<MemRefType>(memref.getType ());
43+ assert (memrefType.getRank () == static_cast <int64_t >(indices.size ()) &&
44+ " rank mismatch" );
45+ Value cond = builder.create <arith::ConstantOp>(
46+ loc, builder.getIntegerAttr (builder.getI1Type (), 1 ));
47+
48+ auto zero = builder.create <arith::ConstantIndexOp>(loc, 0 );
49+ for (auto [dim, idx] : llvm::enumerate (indices)) {
50+ Value dimOp = builder.createOrFold <memref::DimOp>(loc, memref, dim);
51+ Value inBounds = generateInBoundsCheck (builder, loc, idx, zero, dimOp);
52+ cond = builder.createOrFold <arith::AndIOp>(loc, cond, inBounds);
53+ }
54+
55+ return cond;
56+ }
57+
3858struct AssumeAlignmentOpInterface
3959 : public RuntimeVerifiableOpInterface::ExternalModel<
4060 AssumeAlignmentOpInterface, AssumeAlignmentOp> {
@@ -186,26 +206,10 @@ struct LoadStoreOpInterface
186206 void generateRuntimeVerification (Operation *op, OpBuilder &builder,
187207 Location loc) const {
188208 auto loadStoreOp = cast<LoadStoreOp>(op);
189-
190- auto memref = loadStoreOp.getMemref ();
191- auto rank = memref.getType ().getRank ();
192- if (rank == 0 ) {
193- return ;
194- }
195- auto indices = loadStoreOp.getIndices ();
196-
197- auto zero = builder.create <arith::ConstantIndexOp>(loc, 0 );
198- Value assertCond;
199- for (auto i : llvm::seq<int64_t >(0 , rank)) {
200- Value dimOp = builder.createOrFold <memref::DimOp>(loc, memref, i);
201- Value inBounds =
202- generateInBoundsCheck (builder, loc, indices[i], zero, dimOp);
203- assertCond =
204- i > 0 ? builder.createOrFold <arith::AndIOp>(loc, assertCond, inBounds)
205- : inBounds;
206- }
209+ Value cond = generateIndicesInBoundsCheck (
210+ builder, loc, loadStoreOp.getMemref (), loadStoreOp.getIndices ());
207211 builder.create <cf::AssertOp>(
208- loc, assertCond ,
212+ loc, cond ,
209213 RuntimeVerifiableOpInterface::generateErrorMessage (
210214 op, " out-of-bounds access" ));
211215 }
@@ -377,9 +381,12 @@ void mlir::memref::registerRuntimeVerifiableOpInterfaceExternalModels(
377381 DialectRegistry ®istry) {
378382 registry.addExtension (+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
379383 AssumeAlignmentOp::attachInterface<AssumeAlignmentOpInterface>(*ctx);
384+ AtomicRMWOp::attachInterface<LoadStoreOpInterface<AtomicRMWOp>>(*ctx);
380385 CastOp::attachInterface<CastOpInterface>(*ctx);
381386 DimOp::attachInterface<DimOpInterface>(*ctx);
382387 ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
388+ GenericAtomicRMWOp::attachInterface<
389+ LoadStoreOpInterface<GenericAtomicRMWOp>>(*ctx);
383390 LoadOp::attachInterface<LoadStoreOpInterface<LoadOp>>(*ctx);
384391 ReinterpretCastOp::attachInterface<ReinterpretCastOpInterface>(*ctx);
385392 StoreOp::attachInterface<LoadStoreOpInterface<StoreOp>>(*ctx);
0 commit comments