@@ -292,13 +292,13 @@ class CIRStoreOpLowering : public mlir::OpConversionPattern<cir::StoreOp> {
292
292
}
293
293
};
294
294
295
- // Lower cir.get_member
295
+ // Lower cir.get_member by aliasing the result memref to the member inside the
296
+ // flattened structure as a byte array. For example
296
297
//
297
298
// clang-format off
298
- //
299
299
// %5 = cir.get_member %1[1] {name = "b"} : !cir.ptr<!named_tuple.named_tuple<"s", [i32, f64, i8]>> -> !cir.ptr<!cir.double>
300
300
//
301
- // to something like
301
+ // is lowered to something like
302
302
//
303
303
// %1 = named_tuple.cast %alloca_0 : memref<!named_tuple.named_tuple<"s", [i32, f64, i8]>> to memref<24xi8>
304
304
// %c8 = arith.constant 8 : index
@@ -328,37 +328,30 @@ class CIRGetMemberOpLowering
328
328
// concrete datalayout, both datalayouts are the same.
329
329
auto *structLayout = dataLayout.getStructLayout (structType);
330
330
331
- // Get the lowered type: memref<!named_tuple.named_tuple<>>
332
- auto memref = mlir::cast<mlir::MemRefType>(adaptor.getAddr ().getType ());
333
331
// Alias the memref of struct to a memref of an i8 array of the same size.
334
332
const std::array linearizedSize{
335
333
static_cast <std::int64_t >(dataLayout.getTypeStoreSize (structType))};
336
- auto flattenMemRef = mlir::MemRefType::get (
337
- linearizedSize, mlir::IntegerType::get (memref. getContext (), 8 ));
334
+ auto flattenedMemRef = mlir::MemRefType::get (
335
+ linearizedSize, mlir::IntegerType::get (getContext (), 8 ));
338
336
// Use a special cast because normal memref cast cannot do such an extreme
339
337
// cast.
340
338
auto bytesMemRef = rewriter.create <mlir::named_tuple::CastOp>(
341
- op.getLoc (), mlir::TypeRange{flattenMemRef },
339
+ op.getLoc (), mlir::TypeRange{flattenedMemRef },
342
340
mlir::ValueRange{adaptor.getAddr ()});
343
341
342
+ auto pointerToMemberTypeToLower = op.getResultTy ();
343
+ // The lowered type of the cir.ptr to the cir.struct member.
344
+ auto memrefToLoweredMemberType =
345
+ typeConverter->convertType (pointerToMemberTypeToLower);
346
+ // Synthesize the byte access to right lowered type.
344
347
auto memberIndex = op.getIndex ();
345
- auto namedTupleType =
346
- mlir::cast<mlir::named_tuple::NamedTupleType>(memref.getElementType ());
347
- // The lowered type of the element to access in the named_tuple.
348
- auto loweredMemberType = namedTupleType.getType (memberIndex);
349
- // memref.view can only cast to another memref. Wrap the target type if it
350
- // is not already a memref (like with a struct with an array member)
351
- mlir::MemRefType elementMemRefTy;
352
- if (mlir::isa<mlir::MemRefType>(loweredMemberType))
353
- elementMemRefTy = mlir::cast<mlir::MemRefType>(loweredMemberType);
354
- else
355
- elementMemRefTy = mlir::MemRefType::get ({}, loweredMemberType);
356
348
auto offset = structLayout->getElementOffset (memberIndex);
357
- // Synthesize the byte access to right lowered type.
358
349
auto byteShift =
359
350
rewriter.create <mlir::arith::ConstantIndexOp>(op.getLoc (), offset);
351
+ // Create the memref pointing to the flattened member location.
360
352
rewriter.replaceOpWithNewOp <mlir::memref::ViewOp>(
361
- op, elementMemRefTy, bytesMemRef, byteShift, mlir::ValueRange{});
353
+ op, memrefToLoweredMemberType, bytesMemRef, byteShift,
354
+ mlir::ValueRange{});
362
355
return mlir::LogicalResult::success ();
363
356
}
364
357
};
@@ -1463,6 +1456,29 @@ void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
1463
1456
cirDataLayout);
1464
1457
}
1465
1458
1459
+ namespace {
1460
+ // Lower a cir.array either as a memref when it has a reference semantics or as
1461
+ // a tensor when it has a value semantics (like inside a struct or union)
1462
+ mlir::Type lowerArrayType (cir::ArrayType type, bool hasValueSemantics,
1463
+ mlir::TypeConverter &converter) {
1464
+ SmallVector<int64_t > shape;
1465
+ mlir::Type curType = type;
1466
+ while (auto arrayType = dyn_cast<cir::ArrayType>(curType)) {
1467
+ shape.push_back (arrayType.getSize ());
1468
+ curType = arrayType.getEltType ();
1469
+ }
1470
+ auto elementType = converter.convertType (curType);
1471
+ // FIXME: The element type might not be converted
1472
+ if (!elementType)
1473
+ return nullptr ;
1474
+ // Arrays in C/C++ have a reference semantics when not in a struct, so use
1475
+ // a memref
1476
+ if (hasValueSemantics)
1477
+ return mlir::RankedTensorType::get (shape, elementType);
1478
+ return mlir::MemRefType::get (shape, elementType);
1479
+ }
1480
+ } // namespace
1481
+
1466
1482
mlir::TypeConverter prepareTypeConverter (mlir::DataLayout &dataLayout) {
1467
1483
mlir::TypeConverter converter;
1468
1484
converter.addConversion ([&](cir::PointerType type) -> mlir::Type {
@@ -1471,6 +1487,7 @@ mlir::TypeConverter prepareTypeConverter(mlir::DataLayout &dataLayout) {
1471
1487
if (!ty)
1472
1488
return nullptr ;
1473
1489
if (isa<cir::ArrayType>(type.getPointee ()))
1490
+ // An array is already lowered as a memref with reference semantics
1474
1491
return ty;
1475
1492
return mlir::MemRefType::get ({}, ty);
1476
1493
});
@@ -1510,23 +1527,23 @@ mlir::TypeConverter prepareTypeConverter(mlir::DataLayout &dataLayout) {
1510
1527
return mlir::BFloat16Type::get (type.getContext ());
1511
1528
});
1512
1529
converter.addConversion ([&](cir::ArrayType type) -> mlir::Type {
1513
- SmallVector<int64_t > shape;
1514
- mlir::Type curType = type;
1515
- while (auto arrayType = dyn_cast<cir::ArrayType>(curType)) {
1516
- shape.push_back (arrayType.getSize ());
1517
- curType = arrayType.getEltType ();
1518
- }
1519
- auto elementType = converter.convertType (curType);
1520
- // FIXME: The element type might not be converted
1521
- if (!elementType)
1522
- return nullptr ;
1523
- return mlir::MemRefType::get (shape, elementType);
1530
+ // Arrays in C/C++ have a reference semantics when not in a
1531
+ // class/struct/union, so use a memref.
1532
+ return lowerArrayType (type, /* hasValueSemantics */ false , converter);
1524
1533
});
1525
1534
converter.addConversion ([&](cir::VectorType type) -> mlir::Type {
1526
1535
auto ty = converter.convertType (type.getEltType ());
1527
1536
return mlir::VectorType::get (type.getSize (), ty);
1528
1537
});
1529
1538
converter.addConversion ([&](cir::StructType type) -> mlir::Type {
1539
+ auto convertWithValueSemanticsArray = [&](mlir::Type t) {
1540
+ if (mlir::isa<cir::ArrayType>(t))
1541
+ // Inside a class/struct/union, an array has value semantics and is
1542
+ // lowered as a tensor.
1543
+ return lowerArrayType (mlir::cast<cir::ArrayType>(t),
1544
+ /* hasValueSemantics */ true , converter);
1545
+ return converter.convertType (t);
1546
+ };
1530
1547
// FIXME(cir): create separate unions, struct, and classes types.
1531
1548
// Convert struct members.
1532
1549
llvm::SmallVector<mlir::Type> mlirMembers;
@@ -1535,13 +1552,13 @@ mlir::TypeConverter prepareTypeConverter(mlir::DataLayout &dataLayout) {
1535
1552
// TODO(cir): This should be properly validated.
1536
1553
case cir::StructType::Struct:
1537
1554
for (auto ty : type.getMembers ())
1538
- mlirMembers.push_back (converter. convertType (ty));
1555
+ mlirMembers.push_back (convertWithValueSemanticsArray (ty));
1539
1556
break ;
1540
1557
// Unions are lowered as only the largest member.
1541
1558
case cir::StructType::Union: {
1542
1559
auto largestMember = type.getLargestMember (dataLayout);
1543
1560
if (largestMember)
1544
- mlirMembers.push_back (converter. convertType (largestMember));
1561
+ mlirMembers.push_back (convertWithValueSemanticsArray (largestMember));
1545
1562
break ;
1546
1563
}
1547
1564
}
0 commit comments