@@ -303,28 +303,39 @@ struct ConvertExtractStridedMetadata final
303303 Value source = extractStridedMetadataOp.getSource ();
304304
305305 MemRefType memrefType = cast<MemRefType>(source.getType ());
306- if (!isMemRefTypeLegalForEmitC (memrefType)) {
306+ if (!isMemRefTypeLegalForEmitC (memrefType))
307307 return rewriter.notifyMatchFailure (
308308 loc, " incompatible memref type for EmitC conversion" );
309- }
310309
311- Type resultType = convertMemRefType (memrefType, getTypeConverter ());
312- if (!resultType) {
313- return rewriter.notifyMatchFailure (loc, " cannot convert result type" );
314- }
315-
316- auto baseptr =
317- cast<MemRefType>(extractStridedMetadataOp.getBaseBuffer ().getType ());
318- auto emitcType = convertMemRefType (baseptr, getTypeConverter ());
319- auto arrT = emitc::ArrayType::get (memrefType.getShape (), emitcType);
320- auto valVar = rewriter.create <emitc::VariableOp>(
321- loc, arrT, emitc::OpaqueAttr::get (rewriter.getContext (), " " ));
310+ emitc::ConstantOp zeroIndex = rewriter.create <emitc::ConstantOp>(
311+ loc, rewriter.getIndexType (), rewriter.getIndexAttr (0 ));
312+ TypedValue<emitc::ArrayType> srcArrayValue =
313+ cast<TypedValue<emitc::ArrayType>>(operands.getSource ());
314+ auto createPointerFromEmitcArray = [loc, &rewriter, &zeroIndex,
315+ srcArrayValue]() -> emitc::ApplyOp {
316+ int64_t rank = srcArrayValue.getType ().getRank ();
317+ llvm::SmallVector<mlir::Value> indices;
318+ for (int i = 0 ; i < rank; ++i) {
319+ indices.push_back (zeroIndex);
320+ }
321+
322+ emitc::SubscriptOp subPtr = rewriter.create <emitc::SubscriptOp>(
323+ loc, srcArrayValue, mlir::ValueRange (indices));
324+ emitc::ApplyOp ptr = rewriter.create <emitc::ApplyOp>(
325+ loc,
326+ emitc::PointerType::get (srcArrayValue.getType ().getElementType ()),
327+ rewriter.getStringAttr (" &" ), subPtr);
328+
329+ return ptr;
330+ };
331+
332+ emitc::ApplyOp srcPtr = createPointerFromEmitcArray ();
322333 auto [strides, offset] = memrefType.getStridesAndOffset ();
323334 Value offsetValue = rewriter.create <emitc::ConstantOp>(
324335 loc, rewriter.getIndexType (), rewriter.getIndexAttr (offset));
325336
326337 SmallVector<Value> results;
327- results.push_back (valVar );
338+ results.push_back (srcPtr );
328339 results.push_back (offsetValue);
329340
330341 for (unsigned i = 0 , e = memrefType.getRank (); i < e; ++i) {
@@ -343,39 +354,6 @@ struct ConvertExtractStridedMetadata final
343354 }
344355};
345356
346- struct ConvertReinterpretCastOp
347- : public OpConversionPattern<memref::ReinterpretCastOp> {
348- using OpConversionPattern::OpConversionPattern;
349-
350- LogicalResult
351- matchAndRewrite (memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
352- ConversionPatternRewriter &rewriter) const override {
353- MemRefType srcType = cast<MemRefType>(castOp.getSource ().getType ());
354-
355- MemRefType targetMemRefType =
356- cast<MemRefType>(castOp.getResult ().getType ());
357-
358- auto srcInEmitC = convertMemRefType (srcType, getTypeConverter ());
359- auto targetInEmitC =
360- convertMemRefType (targetMemRefType, getTypeConverter ());
361- if (!srcInEmitC || !targetInEmitC) {
362- return rewriter.notifyMatchFailure (castOp.getLoc (),
363- " cannot convert memref type" );
364- }
365-
366- // Create descriptor.
367- Location loc = castOp.getLoc ();
368-
369- auto vals = adaptor.getOperands ();
370-
371- auto res =
372- UnrealizedConversionCastOp::create (rewriter, loc, targetInEmitC, vals)
373- .getResult (0 );
374-
375- return success ();
376- }
377- };
378-
379357} // namespace
380358
381359void mlir::populateMemRefToEmitCTypeConversion (TypeConverter &typeConverter) {
@@ -409,7 +387,6 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
409387void mlir::populateMemRefToEmitCConversionPatterns (
410388 RewritePatternSet &patterns, const TypeConverter &converter) {
411389 patterns.add <ConvertAlloca, ConvertAlloc, ConvertExtractStridedMetadata,
412- ConvertGlobal, ConvertGetGlobal, ConvertLoad,
413- ConvertReinterpretCastOp, ConvertStore>(converter,
414- patterns.getContext ());
390+ ConvertGlobal, ConvertGetGlobal, ConvertLoad, ConvertStore>(
391+ converter, patterns.getContext ());
415392}
0 commit comments