@@ -35,6 +35,26 @@ Value generateInBoundsCheck(OpBuilder &builder, Location loc, Value value,
3535 return inBounds;
3636}
3737
38+ // / Generate a runtime check to see if the given indices are in-bounds with
39+ // / respect to the given ranked memref.
40+ Value generateIndicesInBoundsCheck (OpBuilder &builder, Location loc,
41+ Value memref, ValueRange indices) {
42+ auto memrefType = cast<MemRefType>(memref.getType ());
43+ assert (memrefType.getRank () == static_cast <int64_t >(indices.size ()) &&
44+ " rank mismatch" );
45+ Value cond = builder.create <arith::ConstantOp>(
46+ loc, builder.getIntegerAttr (builder.getI1Type (), 1 ));
47+
48+ auto zero = builder.create <arith::ConstantIndexOp>(loc, 0 );
49+ for (auto [dim, idx] : llvm::enumerate (indices)) {
50+ Value dimOp = builder.createOrFold <memref::DimOp>(loc, memref, dim);
51+ Value inBounds = generateInBoundsCheck (builder, loc, idx, zero, dimOp);
52+ cond = builder.createOrFold <arith::AndIOp>(loc, cond, inBounds);
53+ }
54+
55+ return cond;
56+ }
57+
3858struct AssumeAlignmentOpInterface
3959 : public RuntimeVerifiableOpInterface::ExternalModel<
4060 AssumeAlignmentOpInterface, AssumeAlignmentOp> {
@@ -230,26 +250,10 @@ struct LoadStoreOpInterface
230250 void generateRuntimeVerification (Operation *op, OpBuilder &builder,
231251 Location loc) const {
232252 auto loadStoreOp = cast<LoadStoreOp>(op);
233-
234- auto memref = loadStoreOp.getMemref ();
235- auto rank = memref.getType ().getRank ();
236- if (rank == 0 ) {
237- return ;
238- }
239- auto indices = loadStoreOp.getIndices ();
240-
241- auto zero = builder.create <arith::ConstantIndexOp>(loc, 0 );
242- Value assertCond;
243- for (auto i : llvm::seq<int64_t >(0 , rank)) {
244- Value dimOp = builder.createOrFold <memref::DimOp>(loc, memref, i);
245- Value inBounds =
246- generateInBoundsCheck (builder, loc, indices[i], zero, dimOp);
247- assertCond =
248- i > 0 ? builder.createOrFold <arith::AndIOp>(loc, assertCond, inBounds)
249- : inBounds;
250- }
253+ Value cond = generateIndicesInBoundsCheck (
254+ builder, loc, loadStoreOp.getMemref (), loadStoreOp.getIndices ());
251255 builder.create <cf::AssertOp>(
252- loc, assertCond ,
256+ loc, cond ,
253257 RuntimeVerifiableOpInterface::generateErrorMessage (
254258 op, " out-of-bounds access" ));
255259 }
@@ -426,10 +430,13 @@ void mlir::memref::registerRuntimeVerifiableOpInterfaceExternalModels(
426430 DialectRegistry ®istry) {
427431 registry.addExtension (+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
428432 AssumeAlignmentOp::attachInterface<AssumeAlignmentOpInterface>(*ctx);
433+ AtomicRMWOp::attachInterface<LoadStoreOpInterface<AtomicRMWOp>>(*ctx);
429434 CastOp::attachInterface<CastOpInterface>(*ctx);
430435 CopyOp::attachInterface<CopyOpInterface>(*ctx);
431436 DimOp::attachInterface<DimOpInterface>(*ctx);
432437 ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
438+ GenericAtomicRMWOp::attachInterface<
439+ LoadStoreOpInterface<GenericAtomicRMWOp>>(*ctx);
433440 LoadOp::attachInterface<LoadStoreOpInterface<LoadOp>>(*ctx);
434441 ReinterpretCastOp::attachInterface<ReinterpretCastOpInterface>(*ctx);
435442 StoreOp::attachInterface<LoadStoreOpInterface<StoreOp>>(*ctx);
0 commit comments