@@ -293,13 +293,13 @@ class CIRStoreOpLowering : public mlir::OpConversionPattern<cir::StoreOp> {
293
293
}
294
294
};
295
295
296
- // Lower cir.get_member
296
+ // Lower cir.get_member by aliasing the result memref to the member inside the
297
+ // flattened structure as a byte array. For example
297
298
//
298
299
// clang-format off
299
- //
300
300
// %5 = cir.get_member %1[1] {name = "b"} : !cir.ptr<!named_tuple.named_tuple<"s", [i32, f64, i8]>> -> !cir.ptr<!cir.double>
301
301
//
302
- // to something like
302
+ // is lowered to something like
303
303
//
304
304
// %1 = named_tuple.cast %alloca_0 : memref<!named_tuple.named_tuple<"s", [i32, f64, i8]>> to memref<24xi8>
305
305
// %c8 = arith.constant 8 : index
@@ -329,37 +329,30 @@ class CIRGetMemberOpLowering
329
329
// concrete datalayout, both datalayouts are the same.
330
330
auto *structLayout = dataLayout.getStructLayout (structType);
331
331
332
- // Get the lowered type: memref<!named_tuple.named_tuple<>>
333
- auto memref = mlir::cast<mlir::MemRefType>(adaptor.getAddr ().getType ());
334
332
// Alias the memref of struct to a memref of an i8 array of the same size.
335
333
const std::array linearizedSize{
336
334
static_cast <std::int64_t >(dataLayout.getTypeStoreSize (structType))};
337
- auto flattenMemRef = mlir::MemRefType::get (
338
- linearizedSize, mlir::IntegerType::get (memref. getContext (), 8 ));
335
+ auto flattenedMemRef = mlir::MemRefType::get (
336
+ linearizedSize, mlir::IntegerType::get (getContext (), 8 ));
339
337
// Use a special cast because normal memref cast cannot do such an extreme
340
338
// cast.
341
339
auto bytesMemRef = rewriter.create <mlir::named_tuple::CastOp>(
342
- op.getLoc (), mlir::TypeRange{flattenMemRef },
340
+ op.getLoc (), mlir::TypeRange{flattenedMemRef },
343
341
mlir::ValueRange{adaptor.getAddr ()});
344
342
343
+ auto pointerToMemberTypeToLower = op.getResultTy ();
344
+ // The lowered type of the cir.ptr to the cir.struct member.
345
+ auto memrefToLoweredMemberType =
346
+ typeConverter->convertType (pointerToMemberTypeToLower);
347
+ // Synthesize the byte access to right lowered type.
345
348
auto memberIndex = op.getIndex ();
346
- auto namedTupleType =
347
- mlir::cast<mlir::named_tuple::NamedTupleType>(memref.getElementType ());
348
- // The lowered type of the element to access in the named_tuple.
349
- auto loweredMemberType = namedTupleType.getType (memberIndex);
350
- // memref.view can only cast to another memref. Wrap the target type if it
351
- // is not already a memref (like with a struct with an array member)
352
- mlir::MemRefType elementMemRefTy;
353
- if (mlir::isa<mlir::MemRefType>(loweredMemberType))
354
- elementMemRefTy = mlir::cast<mlir::MemRefType>(loweredMemberType);
355
- else
356
- elementMemRefTy = mlir::MemRefType::get ({}, loweredMemberType);
357
349
auto offset = structLayout->getElementOffset (memberIndex);
358
- // Synthesize the byte access to right lowered type.
359
350
auto byteShift =
360
351
rewriter.create <mlir::arith::ConstantIndexOp>(op.getLoc (), offset);
352
+ // Create the memref pointing to the flattened member location.
361
353
rewriter.replaceOpWithNewOp <mlir::memref::ViewOp>(
362
- op, elementMemRefTy, bytesMemRef, byteShift, mlir::ValueRange{});
354
+ op, memrefToLoweredMemberType, bytesMemRef, byteShift,
355
+ mlir::ValueRange{});
363
356
return mlir::LogicalResult::success ();
364
357
}
365
358
};
@@ -1462,6 +1455,29 @@ void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
1462
1455
cirDataLayout);
1463
1456
}
1464
1457
1458
+ namespace {
1459
+ // Lower a cir.array either as a memref when it has a reference semantics or as
1460
+ // a tensor when it has a value semantics (like inside a struct or union)
1461
+ mlir::Type lowerArrayType (cir::ArrayType type, bool hasValueSemantics,
1462
+ mlir::TypeConverter &converter) {
1463
+ SmallVector<int64_t > shape;
1464
+ mlir::Type curType = type;
1465
+ while (auto arrayType = dyn_cast<cir::ArrayType>(curType)) {
1466
+ shape.push_back (arrayType.getSize ());
1467
+ curType = arrayType.getEltType ();
1468
+ }
1469
+ auto elementType = converter.convertType (curType);
1470
+ // FIXME: The element type might not be converted
1471
+ if (!elementType)
1472
+ return nullptr ;
1473
+ // Arrays in C/C++ have a reference semantics when not in a struct, so use
1474
+ // a memref
1475
+ if (hasValueSemantics)
1476
+ return mlir::RankedTensorType::get (shape, elementType);
1477
+ return mlir::MemRefType::get (shape, elementType);
1478
+ }
1479
+ } // namespace
1480
+
1465
1481
mlir::TypeConverter prepareTypeConverter (mlir::DataLayout &dataLayout) {
1466
1482
mlir::TypeConverter converter;
1467
1483
converter.addConversion ([&](cir::PointerType type) -> mlir::Type {
@@ -1470,6 +1486,7 @@ mlir::TypeConverter prepareTypeConverter(mlir::DataLayout &dataLayout) {
1470
1486
if (!ty)
1471
1487
return nullptr ;
1472
1488
if (isa<cir::ArrayType>(type.getPointee ()))
1489
+ // An array is already lowered as a memref with reference semantics
1473
1490
return ty;
1474
1491
return mlir::MemRefType::get ({}, ty);
1475
1492
});
@@ -1509,23 +1526,23 @@ mlir::TypeConverter prepareTypeConverter(mlir::DataLayout &dataLayout) {
1509
1526
return mlir::BFloat16Type::get (type.getContext ());
1510
1527
});
1511
1528
converter.addConversion ([&](cir::ArrayType type) -> mlir::Type {
1512
- SmallVector<int64_t > shape;
1513
- mlir::Type curType = type;
1514
- while (auto arrayType = dyn_cast<cir::ArrayType>(curType)) {
1515
- shape.push_back (arrayType.getSize ());
1516
- curType = arrayType.getEltType ();
1517
- }
1518
- auto elementType = converter.convertType (curType);
1519
- // FIXME: The element type might not be converted
1520
- if (!elementType)
1521
- return nullptr ;
1522
- return mlir::MemRefType::get (shape, elementType);
1529
+ // Arrays in C/C++ have a reference semantics when not in a
1530
+ // class/struct/union, so use a memref.
1531
+ return lowerArrayType (type, /* hasValueSemantics */ false , converter);
1523
1532
});
1524
1533
converter.addConversion ([&](cir::VectorType type) -> mlir::Type {
1525
1534
auto ty = converter.convertType (type.getEltType ());
1526
1535
return mlir::VectorType::get (type.getSize (), ty);
1527
1536
});
1528
1537
converter.addConversion ([&](cir::StructType type) -> mlir::Type {
1538
+ auto convertWithValueSemanticsArray = [&](mlir::Type t) {
1539
+ if (mlir::isa<cir::ArrayType>(t))
1540
+ // Inside a class/struct/union, an array has value semantics and is
1541
+ // lowered as a tensor.
1542
+ return lowerArrayType (mlir::cast<cir::ArrayType>(t),
1543
+ /* hasValueSemantics */ true , converter);
1544
+ return converter.convertType (t);
1545
+ };
1529
1546
// FIXME(cir): create separate unions, struct, and classes types.
1530
1547
// Convert struct members.
1531
1548
llvm::SmallVector<mlir::Type> mlirMembers;
@@ -1534,13 +1551,13 @@ mlir::TypeConverter prepareTypeConverter(mlir::DataLayout &dataLayout) {
1534
1551
// TODO(cir): This should be properly validated.
1535
1552
case cir::StructType::Struct:
1536
1553
for (auto ty : type.getMembers ())
1537
- mlirMembers.push_back (converter. convertType (ty));
1554
+ mlirMembers.push_back (convertWithValueSemanticsArray (ty));
1538
1555
break ;
1539
1556
// Unions are lowered as only the largest member.
1540
1557
case cir::StructType::Union: {
1541
1558
auto largestMember = type.getLargestMember (dataLayout);
1542
1559
if (largestMember)
1543
- mlirMembers.push_back (converter. convertType (largestMember));
1560
+ mlirMembers.push_back (convertWithValueSemanticsArray (largestMember));
1544
1561
break ;
1545
1562
}
1546
1563
}
0 commit comments