@@ -182,6 +182,53 @@ struct CastOpInterface
182182 }
183183};
184184
185+ struct CopyOpInterface
186+ : public RuntimeVerifiableOpInterface::ExternalModel<CopyOpInterface,
187+ CopyOp> {
188+ void generateRuntimeVerification (Operation *op, OpBuilder &builder,
189+ Location loc) const {
190+ auto copyOp = cast<CopyOp>(op);
191+ BaseMemRefType sourceType = copyOp.getSource ().getType ();
192+ BaseMemRefType targetType = copyOp.getTarget ().getType ();
193+ auto rankedSourceType = dyn_cast<MemRefType>(sourceType);
194+ auto rankedTargetType = dyn_cast<MemRefType>(targetType);
195+
196+ // TODO: Verification for unranked memrefs is not supported yet.
197+ if (!rankedSourceType || !rankedTargetType)
198+ return ;
199+
200+ assert (sourceType.getRank () == targetType.getRank () && " rank mismatch" );
201+ for (int64_t i = 0 , e = sourceType.getRank (); i < e; ++i) {
202+ // Fully static dimensions in both source and target operand are already
203+ // verified by the op verifier.
204+ if (!rankedSourceType.isDynamicDim (i) &&
205+ !rankedTargetType.isDynamicDim (i))
206+ continue ;
207+ Value sourceDim;
208+ if (rankedSourceType.isDynamicDim (i)) {
209+ sourceDim = builder.create <DimOp>(loc, copyOp.getSource (), i);
210+ } else {
211+ sourceDim = builder.create <arith::ConstantIndexOp>(
212+ loc, rankedSourceType.getDimSize (i));
213+ }
214+ Value targetDim;
215+ if (rankedTargetType.isDynamicDim (i)) {
216+ targetDim = builder.create <DimOp>(loc, copyOp.getTarget (), i);
217+ } else {
218+ targetDim = builder.create <arith::ConstantIndexOp>(
219+ loc, rankedTargetType.getDimSize (i));
220+ }
221+ Value sameDimSize = builder.create <arith::CmpIOp>(
222+ loc, arith::CmpIPredicate::eq, sourceDim, targetDim);
223+ builder.create <cf::AssertOp>(
224+ loc, sameDimSize,
225+ RuntimeVerifiableOpInterface::generateErrorMessage (
226+ op, " size of " + std::to_string (i) +
227+ " -th source/target dim does not match" ));
228+ }
229+ }
230+ };
231+
185232struct DimOpInterface
186233 : public RuntimeVerifiableOpInterface::ExternalModel<DimOpInterface,
187234 DimOp> {
@@ -383,6 +430,7 @@ void mlir::memref::registerRuntimeVerifiableOpInterfaceExternalModels(
383430 AssumeAlignmentOp::attachInterface<AssumeAlignmentOpInterface>(*ctx);
384431 AtomicRMWOp::attachInterface<LoadStoreOpInterface<AtomicRMWOp>>(*ctx);
385432 CastOp::attachInterface<CastOpInterface>(*ctx);
433+ CopyOp::attachInterface<CopyOpInterface>(*ctx);
386434 DimOp::attachInterface<DimOpInterface>(*ctx);
387435 ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
388436 GenericAtomicRMWOp::attachInterface<
0 commit comments