@@ -128,6 +128,50 @@ struct CastOpInterface
128128 }
129129};
130130
131+ struct CopyOpInterface
132+ : public RuntimeVerifiableOpInterface::ExternalModel<CopyOpInterface,
133+ CopyOp> {
134+ void generateRuntimeVerification (Operation *op, OpBuilder &builder,
135+ Location loc) const {
136+ auto copyOp = cast<CopyOp>(op);
137+ BaseMemRefType sourceType = copyOp.getSource ().getType ();
138+ BaseMemRefType targetType = copyOp.getTarget ().getType ();
139+ auto rankedSourceType = dyn_cast<MemRefType>(sourceType);
140+ auto rankedTargetType = dyn_cast<MemRefType>(targetType);
141+
142+ // TODO: Verification for unranked memrefs is not supported yet.
143+ if (!rankedSourceType || !rankedTargetType)
144+ return ;
145+
146+ assert (sourceType.getRank () == targetType.getRank () && " rank mismatch" );
147+ for (int64_t i = 0 , e = sourceType.getRank (); i < e; ++i) {
148+ // Fully static dimensions in both source and target operand are already
149+ // verified by the op verifier.
150+ if (!rankedSourceType.isDynamicDim (i) &&
151+ !rankedTargetType.isDynamicDim (i))
152+ continue ;
153+ auto getDimSize = [&](Value memRef, MemRefType type,
154+ int64_t dim) -> Value {
155+ return type.isDynamicDim (dim)
156+ ? builder.create <DimOp>(loc, memRef, dim).getResult ()
157+ : builder
158+ .create <arith::ConstantIndexOp>(loc,
159+ type.getDimSize (dim))
160+ .getResult ();
161+ };
162+ Value sourceDim = getDimSize (copyOp.getSource (), rankedSourceType, i);
163+ Value targetDim = getDimSize (copyOp.getTarget (), rankedTargetType, i);
164+ Value sameDimSize = builder.create <arith::CmpIOp>(
165+ loc, arith::CmpIPredicate::eq, sourceDim, targetDim);
166+ builder.create <cf::AssertOp>(
167+ loc, sameDimSize,
168+ RuntimeVerifiableOpInterface::generateErrorMessage (
169+ op, " size of " + std::to_string (i) +
170+ " -th source/target dim does not match" ));
171+ }
172+ }
173+ };
174+
131175// / Verifies that the indices on load/store ops are in-bounds of the memref's
132176// / index space: 0 <= index#i < dim#i
133177template <typename LoadStoreOp>
@@ -335,6 +379,7 @@ void mlir::memref::registerRuntimeVerifiableOpInterfaceExternalModels(
335379 DialectRegistry ®istry) {
336380 registry.addExtension (+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
337381 CastOp::attachInterface<CastOpInterface>(*ctx);
382+ CopyOp::attachInterface<CopyOpInterface>(*ctx);
338383 ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
339384 LoadOp::attachInterface<LoadStoreOpInterface<LoadOp>>(*ctx);
340385 ReinterpretCastOp::attachInterface<ReinterpretCastOpInterface>(*ctx);
0 commit comments