Skip to content

Commit 6752ed3

Browse files
committed
initial work on metadata ops
1 parent 1b9ee0b commit 6752ed3

File tree

1 file changed

+90
-2
lines changed

1 file changed

+90
-2
lines changed

mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp

Lines changed: 90 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,12 @@
1616
#include "mlir/Dialect/EmitC/IR/EmitC.h"
1717
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1818
#include "mlir/IR/Builders.h"
19+
#include "mlir/IR/BuiltinOps.h"
1920
#include "mlir/IR/BuiltinTypes.h"
2021
#include "mlir/IR/PatternMatch.h"
2122
#include "mlir/IR/TypeRange.h"
2223
#include "mlir/IR/Value.h"
24+
#include "mlir/IR/ValueRange.h"
2325
#include "mlir/Transforms/DialectConversion.h"
2426
#include <cstdint>
2527

@@ -288,6 +290,90 @@ struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
288290
return success();
289291
}
290292
};
293+
294+
struct ConvertExtractStridedMetadata final
295+
: public OpConversionPattern<memref::ExtractStridedMetadataOp> {
296+
using OpConversionPattern::OpConversionPattern;
297+
298+
LogicalResult
299+
matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
300+
OpAdaptor operands,
301+
ConversionPatternRewriter &rewriter) const override {
302+
Location loc = extractStridedMetadataOp.getLoc();
303+
Value source = extractStridedMetadataOp.getSource();
304+
305+
MemRefType memrefType = cast<MemRefType>(source.getType());
306+
if (!isMemRefTypeLegalForEmitC(memrefType)) {
307+
return rewriter.notifyMatchFailure(
308+
loc, "incompatible memref type for EmitC conversion");
309+
}
310+
311+
Type resultType = convertMemRefType(memrefType, getTypeConverter());
312+
if (!resultType) {
313+
return rewriter.notifyMatchFailure(loc, "cannot convert result type");
314+
}
315+
316+
auto baseptr =
317+
cast<MemRefType>(extractStridedMetadataOp.getBaseBuffer().getType());
318+
auto emitcType = convertMemRefType(baseptr, getTypeConverter());
319+
320+
auto [strides, offset] = memrefType.getStridesAndOffset();
321+
Value offsetValue = rewriter.create<emitc::ConstantOp>(
322+
loc, rewriter.getIndexType(), rewriter.getIndexAttr(offset));
323+
324+
SmallVector<Value> results;
325+
results.push_back(extractStridedMetadataOp.getBaseBuffer());
326+
results.push_back(offsetValue);
327+
328+
for (unsigned i = 0, e = memrefType.getRank(); i < e; ++i) {
329+
Value sizeValue = rewriter.create<emitc::ConstantOp>(
330+
loc, rewriter.getIndexType(),
331+
rewriter.getIndexAttr(memrefType.getDimSize(i)));
332+
results.push_back(sizeValue);
333+
334+
Value strideValue = rewriter.create<emitc::ConstantOp>(
335+
loc, rewriter.getIndexType(), rewriter.getIndexAttr(strides[i]));
336+
results.push_back(strideValue);
337+
}
338+
339+
rewriter.replaceOp(extractStridedMetadataOp, results);
340+
return success();
341+
}
342+
};
343+
344+
struct ConvertReinterpretCastOp
345+
: public OpConversionPattern<memref::ReinterpretCastOp> {
346+
using OpConversionPattern::OpConversionPattern;
347+
348+
LogicalResult
349+
matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
350+
ConversionPatternRewriter &rewriter) const override {
351+
MemRefType srcType = cast<MemRefType>(castOp.getSource().getType());
352+
353+
MemRefType targetMemRefType =
354+
cast<MemRefType>(castOp.getResult().getType());
355+
356+
auto srcInEmitC = convertMemRefType(srcType, getTypeConverter());
357+
auto targetInEmitC =
358+
convertMemRefType(targetMemRefType, getTypeConverter());
359+
if (!srcInEmitC || !targetInEmitC) {
360+
return rewriter.notifyMatchFailure(castOp.getLoc(),
361+
"cannot convert memref type");
362+
}
363+
364+
// Create descriptor.
365+
Location loc = castOp.getLoc();
366+
367+
auto vals = adaptor.getOperands();
368+
369+
auto res =
370+
UnrealizedConversionCastOp::create(rewriter, loc, targetInEmitC, vals)
371+
.getResult(0);
372+
373+
return success();
374+
}
375+
};
376+
291377
} // namespace
292378

293379
void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
@@ -320,6 +406,8 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
320406

321407
void mlir::populateMemRefToEmitCConversionPatterns(
322408
RewritePatternSet &patterns, const TypeConverter &converter) {
323-
patterns.add<ConvertAlloca, ConvertAlloc, ConvertGlobal, ConvertGetGlobal,
324-
ConvertLoad, ConvertStore>(converter, patterns.getContext());
409+
patterns.add<ConvertAlloca, ConvertAlloc, ConvertExtractStridedMetadata,
410+
ConvertGlobal, ConvertGetGlobal, ConvertLoad,
411+
ConvertReinterpretCastOp, ConvertStore>(converter,
412+
patterns.getContext());
325413
}

0 commit comments

Comments
 (0)