@@ -289,13 +289,13 @@ class CIRStoreOpLowering : public mlir::OpConversionPattern<cir::StoreOp> {
289
289
}
290
290
};
291
291
292
- // Lower cir.get_member
292
+ // Lower cir.get_member by aliasing the result memref to the member inside the
293
+ // flattened structure as a byte array. For example
293
294
//
294
295
// clang-format off
295
- //
296
296
// %5 = cir.get_member %1[1] {name = "b"} : !cir.ptr<!named_tuple.named_tuple<"s", [i32, f64, i8]>> -> !cir.ptr<!cir.double>
297
297
//
298
- // to something like
298
+ // is lowered to something like
299
299
//
300
300
// %1 = named_tuple.cast %alloca_0 : memref<!named_tuple.named_tuple<"s", [i32, f64, i8]>> to memref<24xi8>
301
301
// %c8 = arith.constant 8 : index
@@ -325,37 +325,30 @@ class CIRGetMemberOpLowering
325
325
// concrete datalayout, both datalayouts are the same.
326
326
auto *structLayout = dataLayout.getStructLayout (structType);
327
327
328
- // Get the lowered type: memref<!named_tuple.named_tuple<>>
329
- auto memref = mlir::cast<mlir::MemRefType>(adaptor.getAddr ().getType ());
330
328
// Alias the memref of struct to a memref of an i8 array of the same size.
331
329
const std::array linearizedSize{
332
330
static_cast <std::int64_t >(dataLayout.getTypeStoreSize (structType))};
333
- auto flattenMemRef = mlir::MemRefType::get (
334
- linearizedSize, mlir::IntegerType::get (memref. getContext (), 8 ));
331
+ auto flattenedMemRef = mlir::MemRefType::get (
332
+ linearizedSize, mlir::IntegerType::get (getContext (), 8 ));
335
333
// Use a special cast because normal memref cast cannot do such an extreme
336
334
// cast.
337
335
auto bytesMemRef = rewriter.create <mlir::named_tuple::CastOp>(
338
- op.getLoc (), mlir::TypeRange{flattenMemRef },
336
+ op.getLoc (), mlir::TypeRange{flattenedMemRef },
339
337
mlir::ValueRange{adaptor.getAddr ()});
340
338
339
+ auto pointerToMemberTypeToLower = op.getResultTy ();
340
+ // The lowered type of the cir.ptr to the cir.struct member.
341
+ auto memrefToLoweredMemberType =
342
+ typeConverter->convertType (pointerToMemberTypeToLower);
343
+ // Synthesize the byte access to right lowered type.
341
344
auto memberIndex = op.getIndex ();
342
- auto namedTupleType =
343
- mlir::cast<mlir::named_tuple::NamedTupleType>(memref.getElementType ());
344
- // The lowered type of the element to access in the named_tuple.
345
- auto loweredMemberType = namedTupleType.getType (memberIndex);
346
- // memref.view can only cast to another memref. Wrap the target type if it
347
- // is not already a memref (like with a struct with an array member)
348
- mlir::MemRefType elementMemRefTy;
349
- if (mlir::isa<mlir::MemRefType>(loweredMemberType))
350
- elementMemRefTy = mlir::cast<mlir::MemRefType>(loweredMemberType);
351
- else
352
- elementMemRefTy = mlir::MemRefType::get ({}, loweredMemberType);
353
345
auto offset = structLayout->getElementOffset (memberIndex);
354
- // Synthesize the byte access to right lowered type.
355
346
auto byteShift =
356
347
rewriter.create <mlir::arith::ConstantIndexOp>(op.getLoc (), offset);
348
+ // Create the memref pointing to the flattened member location.
357
349
rewriter.replaceOpWithNewOp <mlir::memref::ViewOp>(
358
- op, elementMemRefTy, bytesMemRef, byteShift, mlir::ValueRange{});
350
+ op, memrefToLoweredMemberType, bytesMemRef, byteShift,
351
+ mlir::ValueRange{});
359
352
return mlir::LogicalResult::success ();
360
353
}
361
354
};
@@ -1382,6 +1375,29 @@ void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
1382
1375
cirDataLayout);
1383
1376
}
1384
1377
1378
+ namespace {
1379
+ // Lower a cir.array either as a memref when it has a reference semantics or as
1380
+ // a tensor when it has a value semantics (like inside a struct or union)
1381
+ mlir::Type lowerArrayType (cir::ArrayType type, bool hasValueSemantics,
1382
+ mlir::TypeConverter &converter) {
1383
+ SmallVector<int64_t > shape;
1384
+ mlir::Type curType = type;
1385
+ while (auto arrayType = dyn_cast<cir::ArrayType>(curType)) {
1386
+ shape.push_back (arrayType.getSize ());
1387
+ curType = arrayType.getEltType ();
1388
+ }
1389
+ auto elementType = converter.convertType (curType);
1390
+ // FIXME: The element type might not be converted
1391
+ if (!elementType)
1392
+ return nullptr ;
1393
+ // Arrays in C/C++ have a reference semantics when not in a struct, so use
1394
+ // a memref
1395
+ if (hasValueSemantics)
1396
+ return mlir::RankedTensorType::get (shape, elementType);
1397
+ return mlir::MemRefType::get (shape, elementType);
1398
+ }
1399
+ } // namespace
1400
+
1385
1401
mlir::TypeConverter prepareTypeConverter (mlir::DataLayout &dataLayout) {
1386
1402
mlir::TypeConverter converter;
1387
1403
converter.addConversion ([&](cir::PointerType type) -> mlir::Type {
@@ -1390,6 +1406,7 @@ mlir::TypeConverter prepareTypeConverter(mlir::DataLayout &dataLayout) {
1390
1406
if (!ty)
1391
1407
return nullptr ;
1392
1408
if (isa<cir::ArrayType>(type.getPointee ()))
1409
+ // An array is already lowered as a memref with reference semantics
1393
1410
return ty;
1394
1411
return mlir::MemRefType::get ({}, ty);
1395
1412
});
@@ -1429,23 +1446,23 @@ mlir::TypeConverter prepareTypeConverter(mlir::DataLayout &dataLayout) {
1429
1446
return mlir::BFloat16Type::get (type.getContext ());
1430
1447
});
1431
1448
converter.addConversion ([&](cir::ArrayType type) -> mlir::Type {
1432
- SmallVector<int64_t > shape;
1433
- mlir::Type curType = type;
1434
- while (auto arrayType = dyn_cast<cir::ArrayType>(curType)) {
1435
- shape.push_back (arrayType.getSize ());
1436
- curType = arrayType.getEltType ();
1437
- }
1438
- auto elementType = converter.convertType (curType);
1439
- // FIXME: The element type might not be converted
1440
- if (!elementType)
1441
- return nullptr ;
1442
- return mlir::MemRefType::get (shape, elementType);
1449
+ // Arrays in C/C++ have a reference semantics when not in a
1450
+ // class/struct/union, so use a memref.
1451
+ return lowerArrayType (type, /* hasValueSemantics */ false , converter);
1443
1452
});
1444
1453
converter.addConversion ([&](cir::VectorType type) -> mlir::Type {
1445
1454
auto ty = converter.convertType (type.getEltType ());
1446
1455
return mlir::VectorType::get (type.getSize (), ty);
1447
1456
});
1448
1457
converter.addConversion ([&](cir::StructType type) -> mlir::Type {
1458
+ auto convertWithValueSemanticsArray = [&](mlir::Type t) {
1459
+ if (mlir::isa<cir::ArrayType>(t))
1460
+ // Inside a class/struct/union, an array has value semantics and is
1461
+ // lowered as a tensor.
1462
+ return lowerArrayType (mlir::cast<cir::ArrayType>(t),
1463
+ /* hasValueSemantics */ true , converter);
1464
+ return converter.convertType (t);
1465
+ };
1449
1466
// FIXME(cir): create separate unions, struct, and classes types.
1450
1467
// Convert struct members.
1451
1468
llvm::SmallVector<mlir::Type> mlirMembers;
@@ -1454,13 +1471,13 @@ mlir::TypeConverter prepareTypeConverter(mlir::DataLayout &dataLayout) {
1454
1471
// TODO(cir): This should be properly validated.
1455
1472
case cir::StructType::Struct:
1456
1473
for (auto ty : type.getMembers ())
1457
- mlirMembers.push_back (converter. convertType (ty));
1474
+ mlirMembers.push_back (convertWithValueSemanticsArray (ty));
1458
1475
break ;
1459
1476
// Unions are lowered as only the largest member.
1460
1477
case cir::StructType::Union: {
1461
1478
auto largestMember = type.getLargestMember (dataLayout);
1462
1479
if (largestMember)
1463
- mlirMembers.push_back (converter. convertType (largestMember));
1480
+ mlirMembers.push_back (convertWithValueSemanticsArray (largestMember));
1464
1481
break ;
1465
1482
}
1466
1483
}
0 commit comments