@@ -128,6 +128,50 @@ struct CastOpInterface
128
128
}
129
129
};
130
130
131
+ struct CopyOpInterface
132
+ : public RuntimeVerifiableOpInterface::ExternalModel<CopyOpInterface,
133
+ CopyOp> {
134
+ void generateRuntimeVerification (Operation *op, OpBuilder &builder,
135
+ Location loc) const {
136
+ auto copyOp = cast<CopyOp>(op);
137
+ BaseMemRefType sourceType = copyOp.getSource ().getType ();
138
+ BaseMemRefType targetType = copyOp.getTarget ().getType ();
139
+ auto rankedSourceType = dyn_cast<MemRefType>(sourceType);
140
+ auto rankedTargetType = dyn_cast<MemRefType>(targetType);
141
+
142
+ // TODO: Verification for unranked memrefs is not supported yet.
143
+ if (!rankedSourceType || !rankedTargetType)
144
+ return ;
145
+
146
+ assert (sourceType.getRank () == targetType.getRank () && " rank mismatch" );
147
+ for (int64_t i = 0 , e = sourceType.getRank (); i < e; ++i) {
148
+ // Fully static dimensions in both source and target operand are already
149
+ // verified by the op verifier.
150
+ if (!rankedSourceType.isDynamicDim (i) &&
151
+ !rankedTargetType.isDynamicDim (i))
152
+ continue ;
153
+ auto getDimSize = [&](Value memRef, MemRefType type,
154
+ int64_t dim) -> Value {
155
+ return type.isDynamicDim (dim)
156
+ ? builder.create <DimOp>(loc, memRef, dim).getResult ()
157
+ : builder
158
+ .create <arith::ConstantIndexOp>(loc,
159
+ type.getDimSize (dim))
160
+ .getResult ();
161
+ };
162
+ Value sourceDim = getDimSize (copyOp.getSource (), rankedSourceType, i);
163
+ Value targetDim = getDimSize (copyOp.getTarget (), rankedTargetType, i);
164
+ Value sameDimSize = builder.create <arith::CmpIOp>(
165
+ loc, arith::CmpIPredicate::eq, sourceDim, targetDim);
166
+ builder.create <cf::AssertOp>(
167
+ loc, sameDimSize,
168
+ RuntimeVerifiableOpInterface::generateErrorMessage (
169
+ op, " size of " + std::to_string (i) +
170
+ " -th source/target dim does not match" ));
171
+ }
172
+ }
173
+ };
174
+
131
175
// / Verifies that the indices on load/store ops are in-bounds of the memref's
132
176
// / index space: 0 <= index#i < dim#i
133
177
template <typename LoadStoreOp>
@@ -335,6 +379,7 @@ void mlir::memref::registerRuntimeVerifiableOpInterfaceExternalModels(
335
379
DialectRegistry ®istry) {
336
380
registry.addExtension (+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
337
381
CastOp::attachInterface<CastOpInterface>(*ctx);
382
+ CopyOp::attachInterface<CopyOpInterface>(*ctx);
338
383
ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
339
384
LoadOp::attachInterface<LoadStoreOpInterface<LoadOp>>(*ctx);
340
385
ReinterpretCastOp::attachInterface<ReinterpretCastOpInterface>(*ctx);
0 commit comments