Commit 7adf560
[CIR][Lowering] Lower arrays in class/struct/union as tensor
Arrays in C/C++ usually have reference semantics and can be lowered to memref. But when inside a class/struct/union, arrays have value semantics and can be lowered as tensor.
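For illustration, with the struct exercised by the test below (int a; double b; char c; float d[5];), the array member now lowers to a tensor element of the named tuple, while a standalone array keeps its memref lowering. A sketch of the resulting types, per the test expectations:

    // Array member inside the struct: value semantics, lowered as a tensor element.
    !named_tuple.named_tuple<"s", [i32, f64, i8, tensor<5xf32>]>
    // A plain local float[5]: reference semantics, lowered as a memref.
    memref<5xf32>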
1 parent 7c1b546 commit 7adf560

2 files changed: +59 -42 lines

clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp

Lines changed: 51 additions & 34 deletions
@@ -292,13 +292,13 @@ class CIRStoreOpLowering : public mlir::OpConversionPattern<cir::StoreOp> {
   }
 };
 
-// Lower cir.get_member
+// Lower cir.get_member by aliasing the result memref to the member inside the
+// flattened structure as a byte array. For example
 //
 // clang-format off
-//
 // %5 = cir.get_member %1[1] {name = "b"} : !cir.ptr<!named_tuple.named_tuple<"s", [i32, f64, i8]>> -> !cir.ptr<!cir.double>
 //
-// to something like
+// is lowered to something like
 //
 // %1 = named_tuple.cast %alloca_0 : memref<!named_tuple.named_tuple<"s", [i32, f64, i8]>> to memref<24xi8>
 // %c8 = arith.constant 8 : index
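For context, the example in this comment continues with a memref.view at the member's byte offset; a sketch with hypothetical SSA names, mirroring the CHECK lines in the test below:

    // Alias the whole struct as raw bytes, then view member "b" (f64 at offset 8).
    %1 = named_tuple.cast %alloca_0 : memref<!named_tuple.named_tuple<"s", [i32, f64, i8]>> to memref<24xi8>
    %c8 = arith.constant 8 : index
    %view = memref.view %1[%c8][] : memref<24xi8> to memref<f64>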
@@ -328,37 +328,30 @@ class CIRGetMemberOpLowering
     // concrete datalayout, both datalayouts are the same.
     auto *structLayout = dataLayout.getStructLayout(structType);
 
-    // Get the lowered type: memref<!named_tuple.named_tuple<>>
-    auto memref = mlir::cast<mlir::MemRefType>(adaptor.getAddr().getType());
     // Alias the memref of struct to a memref of an i8 array of the same size.
     const std::array linearizedSize{
         static_cast<std::int64_t>(dataLayout.getTypeStoreSize(structType))};
-    auto flattenMemRef = mlir::MemRefType::get(
-        linearizedSize, mlir::IntegerType::get(memref.getContext(), 8));
+    auto flattenedMemRef = mlir::MemRefType::get(
+        linearizedSize, mlir::IntegerType::get(getContext(), 8));
     // Use a special cast because normal memref cast cannot do such an extreme
     // cast.
     auto bytesMemRef = rewriter.create<mlir::named_tuple::CastOp>(
-        op.getLoc(), mlir::TypeRange{flattenMemRef},
+        op.getLoc(), mlir::TypeRange{flattenedMemRef},
         mlir::ValueRange{adaptor.getAddr()});
 
+    auto pointerToMemberTypeToLower = op.getResultTy();
+    // The lowered type of the cir.ptr to the cir.struct member.
+    auto memrefToLoweredMemberType =
+        typeConverter->convertType(pointerToMemberTypeToLower);
+    // Synthesize the byte access to the right lowered type.
     auto memberIndex = op.getIndex();
-    auto namedTupleType =
-        mlir::cast<mlir::named_tuple::NamedTupleType>(memref.getElementType());
-    // The lowered type of the element to access in the named_tuple.
-    auto loweredMemberType = namedTupleType.getType(memberIndex);
-    // memref.view can only cast to another memref. Wrap the target type if it
-    // is not already a memref (like with a struct with an array member)
-    mlir::MemRefType elementMemRefTy;
-    if (mlir::isa<mlir::MemRefType>(loweredMemberType))
-      elementMemRefTy = mlir::cast<mlir::MemRefType>(loweredMemberType);
-    else
-      elementMemRefTy = mlir::MemRefType::get({}, loweredMemberType);
     auto offset = structLayout->getElementOffset(memberIndex);
-    // Synthesize the byte access to right lowered type.
     auto byteShift =
         rewriter.create<mlir::arith::ConstantIndexOp>(op.getLoc(), offset);
+    // Create the memref pointing to the flattened member location.
     rewriter.replaceOpWithNewOp<mlir::memref::ViewOp>(
-        op, elementMemRefTy, bytesMemRef, byteShift, mlir::ValueRange{});
+        op, memrefToLoweredMemberType, bytesMemRef, byteShift,
+        mlir::ValueRange{});
     return mlir::LogicalResult::success();
   }
 };
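With this hunk, the memref.view result type is obtained by running the type converter on the op's cir.ptr result type instead of being reconstructed from the named_tuple element. For member d of the test struct, the emitted access becomes roughly the following sketch (hypothetical SSA names, matching the updated CHECK lines):

    %bytes = named_tuple.cast %alloca : memref<!named_tuple.named_tuple<"s", [i32, f64, i8, tensor<5xf32>]>> to memref<40xi8>
    %c20 = arith.constant 20 : index
    // A direct memref<5xf32> view, not a memref of memref.
    %view_d = memref.view %bytes[%c20][] : memref<40xi8> to memref<5xf32>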
@@ -1463,6 +1456,29 @@ void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
                                          cirDataLayout);
 }
 
+namespace {
+// Lower a cir.array either as a memref when it has reference semantics or as
+// a tensor when it has value semantics (like inside a struct or union).
+mlir::Type lowerArrayType(cir::ArrayType type, bool hasValueSemantics,
+                          mlir::TypeConverter &converter) {
+  SmallVector<int64_t> shape;
+  mlir::Type curType = type;
+  while (auto arrayType = dyn_cast<cir::ArrayType>(curType)) {
+    shape.push_back(arrayType.getSize());
+    curType = arrayType.getEltType();
+  }
+  auto elementType = converter.convertType(curType);
+  // FIXME: The element type might not be converted
+  if (!elementType)
+    return nullptr;
+  // Arrays in C/C++ have reference semantics when not in a struct, so use
+  // a memref.
+  if (hasValueSemantics)
+    return mlir::RankedTensorType::get(shape, elementType);
+  return mlir::MemRefType::get(shape, elementType);
+}
+} // namespace
+
 mlir::TypeConverter prepareTypeConverter(mlir::DataLayout &dataLayout) {
   mlir::TypeConverter converter;
   converter.addConversion([&](cir::PointerType type) -> mlir::Type {
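As a sketch of the mapping this helper implements (CIR type syntax used for illustration): the loop collects array dimensions outermost first, so nested arrays flatten into one ranked type:

    // Assuming !cir.array<!cir.array<!cir.float x 3> x 2>:
    //   hasValueSemantics = false -> memref<2x3xf32>  (reference semantics)
    //   hasValueSemantics = true  -> tensor<2x3xf32>  (value semantics, e.g. a struct member)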
@@ -1471,6 +1487,7 @@ mlir::TypeConverter prepareTypeConverter(mlir::DataLayout &dataLayout) {
     if (!ty)
       return nullptr;
     if (isa<cir::ArrayType>(type.getPointee()))
+      // An array is already lowered as a memref with reference semantics.
       return ty;
     return mlir::MemRefType::get({}, ty);
   });
@@ -1510,23 +1527,23 @@ mlir::TypeConverter prepareTypeConverter(mlir::DataLayout &dataLayout) {
     return mlir::BFloat16Type::get(type.getContext());
   });
   converter.addConversion([&](cir::ArrayType type) -> mlir::Type {
-    SmallVector<int64_t> shape;
-    mlir::Type curType = type;
-    while (auto arrayType = dyn_cast<cir::ArrayType>(curType)) {
-      shape.push_back(arrayType.getSize());
-      curType = arrayType.getEltType();
-    }
-    auto elementType = converter.convertType(curType);
-    // FIXME: The element type might not be converted
-    if (!elementType)
-      return nullptr;
-    return mlir::MemRefType::get(shape, elementType);
+    // Arrays in C/C++ have reference semantics when not in a
+    // class/struct/union, so use a memref.
+    return lowerArrayType(type, /* hasValueSemantics */ false, converter);
   });
   converter.addConversion([&](cir::VectorType type) -> mlir::Type {
     auto ty = converter.convertType(type.getEltType());
     return mlir::VectorType::get(type.getSize(), ty);
   });
   converter.addConversion([&](cir::StructType type) -> mlir::Type {
+    auto convertWithValueSemanticsArray = [&](mlir::Type t) {
+      if (mlir::isa<cir::ArrayType>(t))
+        // Inside a class/struct/union, an array has value semantics and is
+        // lowered as a tensor.
+        return lowerArrayType(mlir::cast<cir::ArrayType>(t),
+                              /* hasValueSemantics */ true, converter);
+      return converter.convertType(t);
+    };
     // FIXME(cir): create separate unions, struct, and classes types.
     // Convert struct members.
     llvm::SmallVector<mlir::Type> mlirMembers;
@@ -1535,13 +1552,13 @@ mlir::TypeConverter prepareTypeConverter(mlir::DataLayout &dataLayout) {
     // TODO(cir): This should be properly validated.
     case cir::StructType::Struct:
       for (auto ty : type.getMembers())
-        mlirMembers.push_back(converter.convertType(ty));
+        mlirMembers.push_back(convertWithValueSemanticsArray(ty));
       break;
     // Unions are lowered as only the largest member.
     case cir::StructType::Union: {
       auto largestMember = type.getLargestMember(dataLayout);
       if (largestMember)
-        mlirMembers.push_back(converter.convertType(largestMember));
+        mlirMembers.push_back(convertWithValueSemanticsArray(largestMember));
       break;
     }
   }
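Unions take the same path with a single member: only the largest member is lowered, and an array member again becomes a tensor. A hypothetical sketch (not part of this patch, assuming the named_tuple lowering also applies to unions):

    // union u { int i; float f[5]; };
    //   largest member: f (float[5]) -> tensor<5xf32>
    //   lowered union                -> !named_tuple.named_tuple<"u", [tensor<5xf32>]>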

clang/test/CIR/Lowering/ThroughMLIR/struct.cpp

Lines changed: 8 additions & 8 deletions
@@ -11,40 +11,40 @@ struct s {
 
 int main() {
   s v;
-  // CHECK: %[[ALLOCA:.+]] = memref.alloca() {alignment = 8 : i64} : memref<!named_tuple.named_tuple<"s", [i32, f64, i8, memref<5xf32>]>>
+  // CHECK: %[[ALLOCA:.+]] = memref.alloca() {alignment = 8 : i64} : memref<!named_tuple.named_tuple<"s", [i32, f64, i8, tensor<5xf32>]>>
   v.a = 7;
   // CHECK: %[[C_7:.+]] = arith.constant 7 : i32
-  // CHECK: %[[I8_EQUIV_A:.+]] = named_tuple.cast %[[ALLOCA]] : memref<!named_tuple.named_tuple<"s", [i32, f64, i8, memref<5xf32>]>> to memref<40xi8>
+  // CHECK: %[[I8_EQUIV_A:.+]] = named_tuple.cast %[[ALLOCA]] : memref<!named_tuple.named_tuple<"s", [i32, f64, i8, tensor<5xf32>]>> to memref<40xi8>
   // CHECK: %[[OFFSET_A:.+]] = arith.constant 0 : index
   // CHECK: %[[VIEW_A:.+]] = memref.view %[[I8_EQUIV_A]][%[[OFFSET_A]]][] : memref<40xi8> to memref<i32>
   // CHECK: memref.store %[[C_7]], %[[VIEW_A]][] : memref<i32>
 
   v.b = 3.;
   // CHECK: %[[C_3:.+]] = arith.constant 3.000000e+00 : f64
-  // CHECK: %[[I8_EQUIV_B:.+]] = named_tuple.cast %[[ALLOCA]] : memref<!named_tuple.named_tuple<"s", [i32, f64, i8, memref<5xf32>]>> to memref<40xi8>
+  // CHECK: %[[I8_EQUIV_B:.+]] = named_tuple.cast %[[ALLOCA]] : memref<!named_tuple.named_tuple<"s", [i32, f64, i8, tensor<5xf32>]>> to memref<40xi8>
   // CHECK: %[[OFFSET_B:.+]] = arith.constant 8 : index
   // CHECK: %[[VIEW_B:.+]] = memref.view %[[I8_EQUIV_B]][%[[OFFSET_B]]][] : memref<40xi8> to memref<f64>
   // CHECK: memref.store %[[C_3]], %[[VIEW_B]][] : memref<f64>
 
   v.c = 'z';
   // CHECK: %[[C_122:.+]] = arith.constant 122 : i8
-  // CHECK: %[[I8_EQUIV_C:.+]] = named_tuple.cast %[[ALLOCA]] : memref<!named_tuple.named_tuple<"s", [i32, f64, i8, memref<5xf32>]>> to memref<40xi8>
+  // CHECK: %[[I8_EQUIV_C:.+]] = named_tuple.cast %[[ALLOCA]] : memref<!named_tuple.named_tuple<"s", [i32, f64, i8, tensor<5xf32>]>> to memref<40xi8>
   // CHECK: %[[OFFSET_C:.+]] = arith.constant 16 : index
   // CHECK: %[[VIEW_C:.+]] = memref.view %[[I8_EQUIV_C]][%[[OFFSET_C]]][] : memref<40xi8> to memref<i8>
   // memref.store %[[C_122]], %[[VIEW_C]][] : memref<i8>
 
+  auto& a = v.d;
   v.d[4] = 6.f;
   // CHECK: %[[C_6:.+]] = arith.constant 6.000000e+00 : f32
-  // CHECK: %[[I8_EQUIV_D:.+]] = named_tuple.cast %alloca_0 : memref<!named_tuple.named_tuple<"s", [i32, f64, i8, memref<5xf32>]>> to memref<40xi8>
+  // CHECK: %[[I8_EQUIV_D:.+]] = named_tuple.cast %[[ALLOCA]] : memref<!named_tuple.named_tuple<"s", [i32, f64, i8, tensor<5xf32>]>> to memref<40xi8>
   // CHECK: %[[OFFSET_D:.+]] = arith.constant 20 : index
-  // Do not lower to a memref of memref
-  // CHECK: %[[VIEW_D:.+]] = memref.view %3[%c20][] : memref<40xi8> to memref<5xf32>
+  // CHECK: %[[VIEW_D:.+]] = memref.view %[[I8_EQUIV_D]][%[[OFFSET_D]]][] : memref<40xi8> to memref<5xf32>
   // CHECK: %[[C_4:.+]] = arith.constant 4 : i32
   // CHECK: %[[I_D:.+]] = arith.index_cast %[[C_4]] : i32 to index
   // CHECK: memref.store %[[C_6]], %[[VIEW_D]][%[[I_D]]] : memref<5xf32>
 
   return v.c;
-  // CHECK: %[[I8_EQUIV_C_1:.+]] = named_tuple.cast %[[ALLOCA]] : memref<!named_tuple.named_tuple<"s", [i32, f64, i8, memref<5xf32>]>> to memref<40xi8>
+  // CHECK: %[[I8_EQUIV_C_1:.+]] = named_tuple.cast %[[ALLOCA]] : memref<!named_tuple.named_tuple<"s", [i32, f64, i8, tensor<5xf32>]>> to memref<40xi8>
   // CHECK: %[[OFFSET_C_1:.+]] = arith.constant 16 : index
   // CHECK: %[[VIEW_C_1:.+]] = memref.view %[[I8_EQUIV_C_1]][%[[OFFSET_C_1]]][] : memref<40xi8> to memref<i8>
   // CHECK: %[[VALUE_C:.+]] = memref.load %[[VIEW_C_1]][] : memref<i8>
