|
17 | 17 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
18 | 18 | #include "mlir/IR/Builders.h" |
19 | 19 | #include "mlir/IR/BuiltinTypes.h" |
| 20 | +#include "mlir/IR/Diagnostics.h" |
20 | 21 | #include "mlir/IR/PatternMatch.h" |
21 | 22 | #include "mlir/IR/TypeRange.h" |
22 | 23 | #include "mlir/IR/Value.h" |
23 | 24 | #include "mlir/Transforms/DialectConversion.h" |
24 | 25 | #include <cstdint> |
| 26 | +#include <numeric> |
25 | 27 |
|
26 | 28 | using namespace mlir; |
27 | 29 |
|
@@ -97,6 +99,48 @@ Type convertMemRefType(MemRefType opTy, const TypeConverter *typeConverter) { |
97 | 99 | return resultTy; |
98 | 100 | } |
99 | 101 |
|
| 102 | +static Value calculateMemrefTotalSizeBytes(Location loc, MemRefType memrefType, |
| 103 | + OpBuilder &builder) { |
| 104 | + assert(isMemRefTypeLegalForEmitC(memrefType) && |
| 105 | + "incompatible memref type for EmitC conversion"); |
| 106 | + emitc::CallOpaqueOp elementSize = emitc::CallOpaqueOp::create( |
| 107 | + builder, loc, emitc::SizeTType::get(builder.getContext()), |
| 108 | + builder.getStringAttr("sizeof"), ValueRange{}, |
| 109 | + ArrayAttr::get(builder.getContext(), |
| 110 | + {TypeAttr::get(memrefType.getElementType())})); |
| 111 | + |
| 112 | + IndexType indexType = builder.getIndexType(); |
| 113 | + int64_t numElements = std::accumulate(memrefType.getShape().begin(), |
| 114 | + memrefType.getShape().end(), int64_t{1}, |
| 115 | + std::multiplies<int64_t>()); |
| 116 | + emitc::ConstantOp numElementsValue = emitc::ConstantOp::create( |
| 117 | + builder, loc, indexType, builder.getIndexAttr(numElements)); |
| 118 | + |
| 119 | + Type sizeTType = emitc::SizeTType::get(builder.getContext()); |
| 120 | + emitc::MulOp totalSizeBytes = emitc::MulOp::create( |
| 121 | + builder, loc, sizeTType, elementSize.getResult(0), numElementsValue); |
| 122 | + |
| 123 | + return totalSizeBytes.getResult(); |
| 124 | +} |
| 125 | + |
| 126 | +static emitc::ApplyOp |
| 127 | +createPointerFromEmitcArray(Location loc, OpBuilder &builder, |
| 128 | + TypedValue<emitc::ArrayType> arrayValue) { |
| 129 | + |
| 130 | + emitc::ConstantOp zeroIndex = emitc::ConstantOp::create( |
| 131 | + builder, loc, builder.getIndexType(), builder.getIndexAttr(0)); |
| 132 | + |
| 133 | + emitc::ArrayType arrayType = arrayValue.getType(); |
| 134 | + llvm::SmallVector<mlir::Value> indices(arrayType.getRank(), zeroIndex); |
| 135 | + emitc::SubscriptOp subPtr = |
| 136 | + emitc::SubscriptOp::create(builder, loc, arrayValue, ValueRange(indices)); |
| 137 | + emitc::ApplyOp ptr = emitc::ApplyOp::create( |
| 138 | + builder, loc, emitc::PointerType::get(arrayType.getElementType()), |
| 139 | + builder.getStringAttr("&"), subPtr); |
| 140 | + |
| 141 | + return ptr; |
| 142 | +} |
| 143 | + |
100 | 144 | struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> { |
101 | 145 | using OpConversionPattern::OpConversionPattern; |
102 | 146 | LogicalResult |
@@ -159,6 +203,47 @@ struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> { |
159 | 203 | } |
160 | 204 | }; |
161 | 205 |
|
| 206 | +struct ConvertCopy final : public OpConversionPattern<memref::CopyOp> { |
| 207 | + using OpConversionPattern::OpConversionPattern; |
| 208 | + |
| 209 | + LogicalResult |
| 210 | + matchAndRewrite(memref::CopyOp copyOp, OpAdaptor operands, |
| 211 | + ConversionPatternRewriter &rewriter) const override { |
| 212 | + Location loc = copyOp.getLoc(); |
| 213 | + MemRefType srcMemrefType = cast<MemRefType>(copyOp.getSource().getType()); |
| 214 | + MemRefType targetMemrefType = |
| 215 | + cast<MemRefType>(copyOp.getTarget().getType()); |
| 216 | + |
| 217 | + if (!isMemRefTypeLegalForEmitC(srcMemrefType)) |
| 218 | + return rewriter.notifyMatchFailure( |
| 219 | + loc, "incompatible source memref type for EmitC conversion"); |
| 220 | + |
| 221 | + if (!isMemRefTypeLegalForEmitC(targetMemrefType)) |
| 222 | + return rewriter.notifyMatchFailure( |
| 223 | + loc, "incompatible target memref type for EmitC conversion"); |
| 224 | + |
| 225 | + auto srcArrayValue = |
| 226 | + cast<TypedValue<emitc::ArrayType>>(operands.getSource()); |
| 227 | + emitc::ApplyOp srcPtr = |
| 228 | + createPointerFromEmitcArray(loc, rewriter, srcArrayValue); |
| 229 | + |
| 230 | + auto targetArrayValue = |
| 231 | + cast<TypedValue<emitc::ArrayType>>(operands.getTarget()); |
| 232 | + emitc::ApplyOp targetPtr = |
| 233 | + createPointerFromEmitcArray(loc, rewriter, targetArrayValue); |
| 234 | + |
| 235 | + emitc::CallOpaqueOp memCpyCall = emitc::CallOpaqueOp::create( |
| 236 | + rewriter, loc, TypeRange{}, "memcpy", |
| 237 | + ValueRange{ |
| 238 | + targetPtr.getResult(), srcPtr.getResult(), |
| 239 | + calculateMemrefTotalSizeBytes(loc, srcMemrefType, rewriter)}); |
| 240 | + |
| 241 | + rewriter.replaceOp(copyOp, memCpyCall.getResults()); |
| 242 | + |
| 243 | + return success(); |
| 244 | + } |
| 245 | +}; |
| 246 | + |
162 | 247 | struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> { |
163 | 248 | using OpConversionPattern::OpConversionPattern; |
164 | 249 |
|
@@ -320,6 +405,7 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) { |
320 | 405 |
|
321 | 406 | void mlir::populateMemRefToEmitCConversionPatterns( |
322 | 407 | RewritePatternSet &patterns, const TypeConverter &converter) { |
323 | | - patterns.add<ConvertAlloca, ConvertAlloc, ConvertGlobal, ConvertGetGlobal, |
324 | | - ConvertLoad, ConvertStore>(converter, patterns.getContext()); |
| 408 | + patterns.add<ConvertAlloca, ConvertAlloc, ConvertCopy, ConvertGlobal, |
| 409 | + ConvertGetGlobal, ConvertLoad, ConvertStore>( |
| 410 | + converter, patterns.getContext()); |
325 | 411 | } |
0 commit comments