Skip to content

Commit 2147fca

Browse files
committed
separate the ops
1 parent 57b34ae commit 2147fca

File tree

2 files changed

+43
-50
lines changed

2 files changed

+43
-50
lines changed

mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp

Lines changed: 27 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -303,28 +303,39 @@ struct ConvertExtractStridedMetadata final
303303
Value source = extractStridedMetadataOp.getSource();
304304

305305
MemRefType memrefType = cast<MemRefType>(source.getType());
306-
if (!isMemRefTypeLegalForEmitC(memrefType)) {
306+
if (!isMemRefTypeLegalForEmitC(memrefType))
307307
return rewriter.notifyMatchFailure(
308308
loc, "incompatible memref type for EmitC conversion");
309-
}
310309

311-
Type resultType = convertMemRefType(memrefType, getTypeConverter());
312-
if (!resultType) {
313-
return rewriter.notifyMatchFailure(loc, "cannot convert result type");
314-
}
315-
316-
auto baseptr =
317-
cast<MemRefType>(extractStridedMetadataOp.getBaseBuffer().getType());
318-
auto emitcType = convertMemRefType(baseptr, getTypeConverter());
319-
auto arrT = emitc::ArrayType::get(memrefType.getShape(), emitcType);
320-
auto valVar = rewriter.create<emitc::VariableOp>(
321-
loc, arrT, emitc::OpaqueAttr::get(rewriter.getContext(), ""));
310+
emitc::ConstantOp zeroIndex = rewriter.create<emitc::ConstantOp>(
311+
loc, rewriter.getIndexType(), rewriter.getIndexAttr(0));
312+
TypedValue<emitc::ArrayType> srcArrayValue =
313+
cast<TypedValue<emitc::ArrayType>>(operands.getSource());
314+
auto createPointerFromEmitcArray = [loc, &rewriter, &zeroIndex,
315+
srcArrayValue]() -> emitc::ApplyOp {
316+
int64_t rank = srcArrayValue.getType().getRank();
317+
llvm::SmallVector<mlir::Value> indices;
318+
for (int i = 0; i < rank; ++i) {
319+
indices.push_back(zeroIndex);
320+
}
321+
322+
emitc::SubscriptOp subPtr = rewriter.create<emitc::SubscriptOp>(
323+
loc, srcArrayValue, mlir::ValueRange(indices));
324+
emitc::ApplyOp ptr = rewriter.create<emitc::ApplyOp>(
325+
loc,
326+
emitc::PointerType::get(srcArrayValue.getType().getElementType()),
327+
rewriter.getStringAttr("&"), subPtr);
328+
329+
return ptr;
330+
};
331+
332+
emitc::ApplyOp srcPtr = createPointerFromEmitcArray();
322333
auto [strides, offset] = memrefType.getStridesAndOffset();
323334
Value offsetValue = rewriter.create<emitc::ConstantOp>(
324335
loc, rewriter.getIndexType(), rewriter.getIndexAttr(offset));
325336

326337
SmallVector<Value> results;
327-
results.push_back(valVar);
338+
results.push_back(srcPtr);
328339
results.push_back(offsetValue);
329340

330341
for (unsigned i = 0, e = memrefType.getRank(); i < e; ++i) {
@@ -343,39 +354,6 @@ struct ConvertExtractStridedMetadata final
343354
}
344355
};
345356

346-
struct ConvertReinterpretCastOp
347-
: public OpConversionPattern<memref::ReinterpretCastOp> {
348-
using OpConversionPattern::OpConversionPattern;
349-
350-
LogicalResult
351-
matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
352-
ConversionPatternRewriter &rewriter) const override {
353-
MemRefType srcType = cast<MemRefType>(castOp.getSource().getType());
354-
355-
MemRefType targetMemRefType =
356-
cast<MemRefType>(castOp.getResult().getType());
357-
358-
auto srcInEmitC = convertMemRefType(srcType, getTypeConverter());
359-
auto targetInEmitC =
360-
convertMemRefType(targetMemRefType, getTypeConverter());
361-
if (!srcInEmitC || !targetInEmitC) {
362-
return rewriter.notifyMatchFailure(castOp.getLoc(),
363-
"cannot convert memref type");
364-
}
365-
366-
// Create descriptor.
367-
Location loc = castOp.getLoc();
368-
369-
auto vals = adaptor.getOperands();
370-
371-
auto res =
372-
UnrealizedConversionCastOp::create(rewriter, loc, targetInEmitC, vals)
373-
.getResult(0);
374-
375-
return success();
376-
}
377-
};
378-
379357
} // namespace
380358

381359
void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
@@ -409,7 +387,6 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
409387
void mlir::populateMemRefToEmitCConversionPatterns(
410388
RewritePatternSet &patterns, const TypeConverter &converter) {
411389
patterns.add<ConvertAlloca, ConvertAlloc, ConvertExtractStridedMetadata,
412-
ConvertGlobal, ConvertGetGlobal, ConvertLoad,
413-
ConvertReinterpretCastOp, ConvertStore>(converter,
414-
patterns.getContext());
390+
ConvertGlobal, ConvertGetGlobal, ConvertLoad, ConvertStore>(
391+
converter, patterns.getContext());
415392
}

mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,3 +58,19 @@ module @globals {
5858
return
5959
}
6060
}
61+
62+
// -----
63+
64+
// CHECK-LABEL: reinterpret_cast
65+
func.func @reinterpret_cast(%arg18: memref<1xi32>) {
66+
// CHECK: %0 = builtin.unrealized_conversion_cast %arg0 : memref<1xi32> to !emitc.array<1xi32>
67+
// CHECK: %1 = "emitc.constant"() <{value = 0 : index}> : () -> index
68+
// CHECK: %2 = emitc.subscript %0[%1] : (!emitc.array<1xi32>, index) -> !emitc.lvalue<i32>
69+
// CHECK: %3 = emitc.apply "&"(%2) : (!emitc.lvalue<i32>) -> !emitc.ptr<i32>
70+
// CHECK: %4 = "emitc.constant"() <{value = 0 : index}> : () -> index
71+
// CHECK: %5 = "emitc.constant"() <{value = 1 : index}> : () -> index
72+
// CHECK: %6 = "emitc.constant"() <{value = 1 : index}> : () -> index
73+
%base_buffer_485, %offset_486, %sizes_487, %strides_488 = memref.extract_strided_metadata %arg18 : memref<1xi32> -> memref<i32>, index, index, index
74+
return
75+
}
76+

0 commit comments

Comments
 (0)