@@ -35,6 +35,26 @@ Value generateInBoundsCheck(OpBuilder &builder, Location loc, Value value,
3535 return inBounds;
3636}
3737
38+ // / Generate a runtime check to see if the given indices are in-bounds with
39+ // / respect to the given ranked memref.
40+ Value generateIndicesInBoundsCheck (OpBuilder &builder, Location loc,
41+ Value memref, ValueRange indices) {
42+ auto memrefType = cast<MemRefType>(memref.getType ());
43+ assert (memrefType.getRank () == static_cast <int64_t >(indices.size ()) &&
44+ " rank mismatch" );
45+ Value cond = builder.create <arith::ConstantOp>(
46+ loc, builder.getIntegerAttr (builder.getI1Type (), 1 ));
47+
48+ auto zero = builder.create <arith::ConstantIndexOp>(loc, 0 );
49+ for (auto [dim, idx] : llvm::enumerate (indices)) {
50+ Value dimOp = builder.createOrFold <memref::DimOp>(loc, memref, dim);
51+ Value inBounds = generateInBoundsCheck (builder, loc, idx, zero, dimOp);
52+ cond = builder.createOrFold <arith::AndIOp>(loc, cond, inBounds);
53+ }
54+
55+ return cond;
56+ }
57+
3858struct AssumeAlignmentOpInterface
3959 : public RuntimeVerifiableOpInterface::ExternalModel<
4060 AssumeAlignmentOpInterface, AssumeAlignmentOp> {
@@ -230,26 +250,10 @@ struct LoadStoreOpInterface
230250 void generateRuntimeVerification (Operation *op, OpBuilder &builder,
231251 Location loc) const {
232252 auto loadStoreOp = cast<LoadStoreOp>(op);
233-
234- auto memref = loadStoreOp.getMemref ();
235- auto rank = memref.getType ().getRank ();
236- if (rank == 0 ) {
237- return ;
238- }
239- auto indices = loadStoreOp.getIndices ();
240-
241- auto zero = builder.create <arith::ConstantIndexOp>(loc, 0 );
242- Value assertCond;
243- for (auto i : llvm::seq<int64_t >(0 , rank)) {
244- Value dimOp = builder.createOrFold <memref::DimOp>(loc, memref, i);
245- Value inBounds =
246- generateInBoundsCheck (builder, loc, indices[i], zero, dimOp);
247- assertCond =
248- i > 0 ? builder.createOrFold <arith::AndIOp>(loc, assertCond, inBounds)
249- : inBounds;
250- }
253+ Value cond = generateIndicesInBoundsCheck (
254+ builder, loc, loadStoreOp.getMemref (), loadStoreOp.getIndices ());
251255 builder.create <cf::AssertOp>(
252- loc, assertCond ,
256+ loc, cond ,
253257 RuntimeVerifiableOpInterface::generateErrorMessage (
254258 op, " out-of-bounds access" ));
255259 }
@@ -421,10 +425,13 @@ void mlir::memref::registerRuntimeVerifiableOpInterfaceExternalModels(
421425 DialectRegistry ®istry) {
422426 registry.addExtension (+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
423427 AssumeAlignmentOp::attachInterface<AssumeAlignmentOpInterface>(*ctx);
428+ AtomicRMWOp::attachInterface<LoadStoreOpInterface<AtomicRMWOp>>(*ctx);
424429 CastOp::attachInterface<CastOpInterface>(*ctx);
425430 CopyOp::attachInterface<CopyOpInterface>(*ctx);
426431 DimOp::attachInterface<DimOpInterface>(*ctx);
427432 ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
433+ GenericAtomicRMWOp::attachInterface<
434+ LoadStoreOpInterface<GenericAtomicRMWOp>>(*ctx);
428435 LoadOp::attachInterface<LoadStoreOpInterface<LoadOp>>(*ctx);
429436 ReinterpretCastOp::attachInterface<ReinterpretCastOpInterface>(*ctx);
430437 StoreOp::attachInterface<LoadStoreOpInterface<StoreOp>>(*ctx);
0 commit comments