|
16 | 16 | #include "mlir/Dialect/EmitC/IR/EmitC.h" |
17 | 17 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
18 | 18 | #include "mlir/IR/Builders.h" |
| 19 | +#include "mlir/IR/BuiltinOps.h" |
19 | 20 | #include "mlir/IR/BuiltinTypes.h" |
20 | 21 | #include "mlir/IR/PatternMatch.h" |
21 | 22 | #include "mlir/IR/TypeRange.h" |
22 | 23 | #include "mlir/IR/Value.h" |
| 24 | +#include "mlir/IR/ValueRange.h" |
23 | 25 | #include "mlir/Transforms/DialectConversion.h" |
24 | 26 | #include <cstdint> |
25 | 27 |
|
@@ -288,6 +290,90 @@ struct ConvertStore final : public OpConversionPattern<memref::StoreOp> { |
288 | 290 | return success(); |
289 | 291 | } |
290 | 292 | }; |
| 293 | + |
| 294 | +struct ConvertExtractStridedMetadata final |
| 295 | + : public OpConversionPattern<memref::ExtractStridedMetadataOp> { |
| 296 | + using OpConversionPattern::OpConversionPattern; |
| 297 | + |
| 298 | + LogicalResult |
| 299 | + matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp, |
| 300 | + OpAdaptor operands, |
| 301 | + ConversionPatternRewriter &rewriter) const override { |
| 302 | + Location loc = extractStridedMetadataOp.getLoc(); |
| 303 | + Value source = extractStridedMetadataOp.getSource(); |
| 304 | + |
| 305 | + MemRefType memrefType = cast<MemRefType>(source.getType()); |
| 306 | + if (!isMemRefTypeLegalForEmitC(memrefType)) { |
| 307 | + return rewriter.notifyMatchFailure( |
| 308 | + loc, "incompatible memref type for EmitC conversion"); |
| 309 | + } |
| 310 | + |
| 311 | + Type resultType = convertMemRefType(memrefType, getTypeConverter()); |
| 312 | + if (!resultType) { |
| 313 | + return rewriter.notifyMatchFailure(loc, "cannot convert result type"); |
| 314 | + } |
| 315 | + |
| 316 | + auto baseptr = |
| 317 | + cast<MemRefType>(extractStridedMetadataOp.getBaseBuffer().getType()); |
| 318 | + auto emitcType = convertMemRefType(baseptr, getTypeConverter()); |
| 319 | + |
| 320 | + auto [strides, offset] = memrefType.getStridesAndOffset(); |
| 321 | + Value offsetValue = rewriter.create<emitc::ConstantOp>( |
| 322 | + loc, rewriter.getIndexType(), rewriter.getIndexAttr(offset)); |
| 323 | + |
| 324 | + SmallVector<Value> results; |
| 325 | + results.push_back(extractStridedMetadataOp.getBaseBuffer()); |
| 326 | + results.push_back(offsetValue); |
| 327 | + |
| 328 | + for (unsigned i = 0, e = memrefType.getRank(); i < e; ++i) { |
| 329 | + Value sizeValue = rewriter.create<emitc::ConstantOp>( |
| 330 | + loc, rewriter.getIndexType(), |
| 331 | + rewriter.getIndexAttr(memrefType.getDimSize(i))); |
| 332 | + results.push_back(sizeValue); |
| 333 | + |
| 334 | + Value strideValue = rewriter.create<emitc::ConstantOp>( |
| 335 | + loc, rewriter.getIndexType(), rewriter.getIndexAttr(strides[i])); |
| 336 | + results.push_back(strideValue); |
| 337 | + } |
| 338 | + |
| 339 | + rewriter.replaceOp(extractStridedMetadataOp, results); |
| 340 | + return success(); |
| 341 | + } |
| 342 | +}; |
| 343 | + |
| 344 | +struct ConvertReinterpretCastOp |
| 345 | + : public OpConversionPattern<memref::ReinterpretCastOp> { |
| 346 | + using OpConversionPattern::OpConversionPattern; |
| 347 | + |
| 348 | + LogicalResult |
| 349 | + matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor, |
| 350 | + ConversionPatternRewriter &rewriter) const override { |
| 351 | + MemRefType srcType = cast<MemRefType>(castOp.getSource().getType()); |
| 352 | + |
| 353 | + MemRefType targetMemRefType = |
| 354 | + cast<MemRefType>(castOp.getResult().getType()); |
| 355 | + |
| 356 | + auto srcInEmitC = convertMemRefType(srcType, getTypeConverter()); |
| 357 | + auto targetInEmitC = |
| 358 | + convertMemRefType(targetMemRefType, getTypeConverter()); |
| 359 | + if (!srcInEmitC || !targetInEmitC) { |
| 360 | + return rewriter.notifyMatchFailure(castOp.getLoc(), |
| 361 | + "cannot convert memref type"); |
| 362 | + } |
| 363 | + |
| 364 | + // Create descriptor. |
| 365 | + Location loc = castOp.getLoc(); |
| 366 | + |
| 367 | + auto vals = adaptor.getOperands(); |
| 368 | + |
| 369 | + auto res = |
| 370 | + UnrealizedConversionCastOp::create(rewriter, loc, targetInEmitC, vals) |
| 371 | + .getResult(0); |
| 372 | + |
| 373 | + return success(); |
| 374 | + } |
| 375 | +}; |
| 376 | + |
291 | 377 | } // namespace |
292 | 378 |
|
293 | 379 | void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) { |
@@ -320,6 +406,8 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) { |
320 | 406 |
|
321 | 407 | void mlir::populateMemRefToEmitCConversionPatterns( |
322 | 408 | RewritePatternSet &patterns, const TypeConverter &converter) { |
323 | | - patterns.add<ConvertAlloca, ConvertAlloc, ConvertGlobal, ConvertGetGlobal, |
324 | | - ConvertLoad, ConvertStore>(converter, patterns.getContext()); |
| 409 | + patterns.add<ConvertAlloca, ConvertAlloc, ConvertExtractStridedMetadata, |
| 410 | + ConvertGlobal, ConvertGetGlobal, ConvertLoad, |
| 411 | + ConvertReinterpretCastOp, ConvertStore>(converter, |
| 412 | + patterns.getContext()); |
325 | 413 | } |
0 commit comments