|
21 | 21 | #include "mlir/IR/TypeRange.h"
|
22 | 22 | #include "mlir/IR/Value.h"
|
23 | 23 | #include "mlir/Transforms/DialectConversion.h"
|
| 24 | +#include "llvm/Support/FormatVariadic.h" |
24 | 25 | #include <cstdint>
|
| 26 | +#include <string> |
25 | 27 |
|
26 | 28 | using namespace mlir;
|
27 | 29 |
|
@@ -269,6 +271,85 @@ struct ConvertLoad final : public OpConversionPattern<memref::LoadOp> {
|
269 | 271 | }
|
270 | 272 | };
|
271 | 273 |
|
| 274 | +struct ConvertReinterpretCastOp final |
| 275 | + : public OpConversionPattern<memref::ReinterpretCastOp> { |
| 276 | + using OpConversionPattern::OpConversionPattern; |
| 277 | + |
| 278 | + LogicalResult |
| 279 | + matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor, |
| 280 | + ConversionPatternRewriter &rewriter) const override { |
| 281 | + |
| 282 | + MemRefType srcType = cast<MemRefType>(castOp.getSource().getType()); |
| 283 | + |
| 284 | + MemRefType targetMemRefType = |
| 285 | + cast<MemRefType>(castOp.getResult().getType()); |
| 286 | + |
| 287 | + auto srcInEmitC = convertMemRefType(srcType, getTypeConverter()); |
| 288 | + auto targetInEmitC = |
| 289 | + convertMemRefType(targetMemRefType, getTypeConverter()); |
| 290 | + if (!srcInEmitC || !targetInEmitC) { |
| 291 | + return rewriter.notifyMatchFailure(castOp.getLoc(), |
| 292 | + "cannot convert memref type"); |
| 293 | + } |
| 294 | + Location loc = castOp.getLoc(); |
| 295 | + |
| 296 | + auto srcArrayValue = |
| 297 | + cast<TypedValue<emitc::ArrayType>>(adaptor.getSource()); |
| 298 | + |
| 299 | + emitc::ConstantOp zeroIndex = rewriter.create<emitc::ConstantOp>( |
| 300 | + loc, rewriter.getIndexType(), rewriter.getIndexAttr(0)); |
| 301 | + |
| 302 | + auto createPointerFromEmitcArray = |
| 303 | + [loc, &rewriter, &zeroIndex]( |
| 304 | + mlir::TypedValue<emitc::ArrayType> arrayValue) -> emitc::ApplyOp { |
| 305 | + int64_t rank = arrayValue.getType().getRank(); |
| 306 | + llvm::SmallVector<mlir::Value> indices; |
| 307 | + for (int i = 0; i < rank; ++i) { |
| 308 | + indices.push_back(zeroIndex); |
| 309 | + } |
| 310 | + |
| 311 | + emitc::SubscriptOp subPtr = rewriter.create<emitc::SubscriptOp>( |
| 312 | + loc, arrayValue, mlir::ValueRange(indices)); |
| 313 | + emitc::ApplyOp ptr = rewriter.create<emitc::ApplyOp>( |
| 314 | + loc, emitc::PointerType::get(arrayValue.getType().getElementType()), |
| 315 | + rewriter.getStringAttr("&"), subPtr); |
| 316 | + |
| 317 | + return ptr; |
| 318 | + }; |
| 319 | + auto [strides, offset] = targetMemRefType.getStridesAndOffset(); |
| 320 | + // Value offsetValue = rewriter.create<emitc::ConstantOp>( |
| 321 | + // loc, rewriter.getIndexType(), rewriter.getIndexAttr(offset)); |
| 322 | + |
| 323 | + auto srcPtr = createPointerFromEmitcArray(srcArrayValue); |
| 324 | + // emitc::PointerType targetPointerType = |
| 325 | + // emitc::PointerType::get(srcArrayValue.getType().getElementType()); |
| 326 | + |
| 327 | + auto dimensions = targetMemRefType.getShape(); |
| 328 | + std::string reinterpretCastName = llvm::formatv( |
| 329 | + "reinterpret_cast<{0}(*)", srcArrayValue.getType().getElementType()); |
| 330 | + std::string dimensionsStr; |
| 331 | + for (auto dim : dimensions) { |
| 332 | + dimensionsStr += llvm::formatv("[{0}]", dim); |
| 333 | + } |
| 334 | + reinterpretCastName += llvm::formatv("{0}>", dimensionsStr); |
| 335 | + reinterpretCastName += ">"; |
| 336 | + |
| 337 | + reinterpretCastName += llvm::formatv("{0}", srcPtr->getResult(0)); |
| 338 | + |
| 339 | + std::string outputStr = llvm::formatv( |
| 340 | + "{0}(*){1}", srcArrayValue.getType().getElementType(), dimensionsStr); |
| 341 | + auto outputType = emitc::PointerType::get( |
| 342 | + emitc::OpaqueType::get(rewriter.getContext(), outputStr)); |
| 343 | + |
| 344 | + emitc::ConstantOp reinterpretOp = rewriter.create<emitc::ConstantOp>( |
| 345 | + loc, outputType, |
| 346 | + emitc::OpaqueAttr::get(rewriter.getContext(), reinterpretCastName)); |
| 347 | + |
| 348 | + rewriter.replaceOp(castOp, reinterpretOp.getResult()); |
| 349 | + return success(); |
| 350 | + } |
| 351 | +}; |
| 352 | + |
272 | 353 | struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
|
273 | 354 | using OpConversionPattern::OpConversionPattern;
|
274 | 355 |
|
@@ -321,5 +402,6 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
|
321 | 402 | void mlir::populateMemRefToEmitCConversionPatterns(
|
322 | 403 | RewritePatternSet &patterns, const TypeConverter &converter) {
|
323 | 404 | patterns.add<ConvertAlloca, ConvertAlloc, ConvertGlobal, ConvertGetGlobal,
|
324 |
| - ConvertLoad, ConvertStore>(converter, patterns.getContext()); |
| 405 | + ConvertLoad, ConvertReinterpretCastOp, ConvertStore>( |
| 406 | + converter, patterns.getContext()); |
325 | 407 | }
|
0 commit comments