
Commit e725880

[CIR][Lowering] Lower arrays in class/struct/union as tensor
Arrays in C/C++ usually have reference semantics and can be lowered to memref. But inside a class/struct/union, arrays have value semantics and can be lowered as tensor.
1 parent a3fa866 commit e725880
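
To make the semantics split concrete, here is a hypothetical C++ sketch; the struct mirrors the test updated below, and the lowered types shown in comments follow this commit's CHECK lines:

struct s { int a; double b; char c; float d[5]; };

void f() {
  // A standalone local array keeps reference semantics and lowers to a memref:
  float standalone[5]; // -> memref<5xf32>
  // An array member has value semantics; the record lowers with a tensor member:
  s v; // -> memref<!named_tuple.named_tuple<"s", [i32, f64, i8, tensor<5xf32>]>>
}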

File tree: 2 files changed (+59 −42 lines)
clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp

Lines changed: 51 additions & 34 deletions
@@ -293,13 +293,13 @@ class CIRStoreOpLowering : public mlir::OpConversionPattern<cir::StoreOp> {
   }
 };
 
-// Lower cir.get_member
+// Lower cir.get_member by aliasing the result memref to the member inside the
+// flattened structure as a byte array. For example
 //
 // clang-format off
-//
 // %5 = cir.get_member %1[1] {name = "b"} : !cir.ptr<!named_tuple.named_tuple<"s", [i32, f64, i8]>> -> !cir.ptr<!cir.double>
 //
-// to something like
+// is lowered to something like
 //
 // %1 = named_tuple.cast %alloca_0 : memref<!named_tuple.named_tuple<"s", [i32, f64, i8]>> to memref<24xi8>
 // %c8 = arith.constant 8 : index
@@ -329,37 +329,30 @@ class CIRGetMemberOpLowering
     // concrete datalayout, both datalayouts are the same.
     auto *structLayout = dataLayout.getStructLayout(structType);
 
-    // Get the lowered type: memref<!named_tuple.named_tuple<>>
-    auto memref = mlir::cast<mlir::MemRefType>(adaptor.getAddr().getType());
     // Alias the memref of struct to a memref of an i8 array of the same size.
     const std::array linearizedSize{
         static_cast<std::int64_t>(dataLayout.getTypeStoreSize(structType))};
-    auto flattenMemRef = mlir::MemRefType::get(
-        linearizedSize, mlir::IntegerType::get(memref.getContext(), 8));
+    auto flattenedMemRef = mlir::MemRefType::get(
+        linearizedSize, mlir::IntegerType::get(getContext(), 8));
     // Use a special cast because normal memref cast cannot do such an extreme
     // cast.
     auto bytesMemRef = rewriter.create<mlir::named_tuple::CastOp>(
-        op.getLoc(), mlir::TypeRange{flattenMemRef},
+        op.getLoc(), mlir::TypeRange{flattenedMemRef},
         mlir::ValueRange{adaptor.getAddr()});
 
+    auto pointerToMemberTypeToLower = op.getResultTy();
+    // The lowered type of the cir.ptr to the cir.struct member.
+    auto memrefToLoweredMemberType =
+        typeConverter->convertType(pointerToMemberTypeToLower);
+    // Synthesize the byte access to the right lowered type.
     auto memberIndex = op.getIndex();
-    auto namedTupleType =
-        mlir::cast<mlir::named_tuple::NamedTupleType>(memref.getElementType());
-    // The lowered type of the element to access in the named_tuple.
-    auto loweredMemberType = namedTupleType.getType(memberIndex);
-    // memref.view can only cast to another memref. Wrap the target type if it
-    // is not already a memref (like with a struct with an array member)
-    mlir::MemRefType elementMemRefTy;
-    if (mlir::isa<mlir::MemRefType>(loweredMemberType))
-      elementMemRefTy = mlir::cast<mlir::MemRefType>(loweredMemberType);
-    else
-      elementMemRefTy = mlir::MemRefType::get({}, loweredMemberType);
     auto offset = structLayout->getElementOffset(memberIndex);
-    // Synthesize the byte access to right lowered type.
     auto byteShift =
         rewriter.create<mlir::arith::ConstantIndexOp>(op.getLoc(), offset);
+    // Create the memref pointing to the flattened member location.
     rewriter.replaceOpWithNewOp<mlir::memref::ViewOp>(
-        op, elementMemRefTy, bytesMemRef, byteShift, mlir::ValueRange{});
+        op, memrefToLoweredMemberType, bytesMemRef, byteShift,
+        mlir::ValueRange{});
     return mlir::LogicalResult::success();
   }
 };
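
For the array member d of the test's struct s, the rewritten pattern now emits the view directly with the converted member type instead of wrapping it in an extra memref. A rough sketch of the resulting MLIR, paraphrasing the CHECK lines of the updated test (the value names %bytes, %c20, %view are illustrative):

// %bytes = named_tuple.cast %alloca : memref<!named_tuple.named_tuple<"s", [i32, f64, i8, tensor<5xf32>]>> to memref<40xi8>
// %c20 = arith.constant 20 : index
// %view = memref.view %bytes[%c20][] : memref<40xi8> to memref<5xf32>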
@@ -1462,6 +1455,29 @@ void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
                                          cirDataLayout);
 }
 
+namespace {
+// Lower a cir.array either as a memref when it has reference semantics or as
+// a tensor when it has value semantics (like inside a struct or union).
+mlir::Type lowerArrayType(cir::ArrayType type, bool hasValueSemantics,
+                          mlir::TypeConverter &converter) {
+  SmallVector<int64_t> shape;
+  mlir::Type curType = type;
+  while (auto arrayType = dyn_cast<cir::ArrayType>(curType)) {
+    shape.push_back(arrayType.getSize());
+    curType = arrayType.getEltType();
+  }
+  auto elementType = converter.convertType(curType);
+  // FIXME: The element type might not be converted
+  if (!elementType)
+    return nullptr;
+  // Arrays in C/C++ have reference semantics when not in a struct, so use
+  // a memref.
+  if (hasValueSemantics)
+    return mlir::RankedTensorType::get(shape, elementType);
+  return mlir::MemRefType::get(shape, elementType);
+}
+} // namespace
+
 mlir::TypeConverter prepareTypeConverter(mlir::DataLayout &dataLayout) {
   mlir::TypeConverter converter;
   converter.addConversion([&](cir::PointerType type) -> mlir::Type {
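
A sketch of what lowerArrayType produces for a nested array (hypothetical input type): the while loop peels each cir.array level into the shape vector and converts the innermost element type, so a two-dimensional array lowers to a single ranked type.

// Given !cir.array<!cir.array<!s32i x 3> x 2>, shape becomes {2, 3} and the
// element type converts to i32, yielding:
//   hasValueSemantics == true  -> tensor<2x3xi32>
//   hasValueSemantics == false -> memref<2x3xi32>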
@@ -1470,6 +1486,7 @@ mlir::TypeConverter prepareTypeConverter(mlir::DataLayout &dataLayout) {
     if (!ty)
       return nullptr;
     if (isa<cir::ArrayType>(type.getPointee()))
+      // An array is already lowered as a memref with reference semantics
       return ty;
     return mlir::MemRefType::get({}, ty);
   });
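
The early return matters because an array pointee already lowers to a memref. A sketch of the pointer conversion under that rule (hypothetical input types):

// !cir.ptr<!cir.double>                -> memref<f64>   (pointee wrapped by the
//                                                        default path)
// !cir.ptr<!cir.array<!cir.float x 5>> -> memref<5xf32> (returned as-is, never
//                                                        memref<memref<5xf32>>)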
@@ -1509,23 +1526,23 @@
     return mlir::BFloat16Type::get(type.getContext());
   });
   converter.addConversion([&](cir::ArrayType type) -> mlir::Type {
-    SmallVector<int64_t> shape;
-    mlir::Type curType = type;
-    while (auto arrayType = dyn_cast<cir::ArrayType>(curType)) {
-      shape.push_back(arrayType.getSize());
-      curType = arrayType.getEltType();
-    }
-    auto elementType = converter.convertType(curType);
-    // FIXME: The element type might not be converted
-    if (!elementType)
-      return nullptr;
-    return mlir::MemRefType::get(shape, elementType);
+    // Arrays in C/C++ have reference semantics when not in a
+    // class/struct/union, so use a memref.
+    return lowerArrayType(type, /* hasValueSemantics */ false, converter);
   });
   converter.addConversion([&](cir::VectorType type) -> mlir::Type {
     auto ty = converter.convertType(type.getEltType());
     return mlir::VectorType::get(type.getSize(), ty);
   });
   converter.addConversion([&](cir::StructType type) -> mlir::Type {
+    auto convertWithValueSemanticsArray = [&](mlir::Type t) {
+      if (mlir::isa<cir::ArrayType>(t))
+        // Inside a class/struct/union, an array has value semantics and is
+        // lowered as a tensor.
+        return lowerArrayType(mlir::cast<cir::ArrayType>(t),
+                              /* hasValueSemantics */ true, converter);
+      return converter.convertType(t);
+    };
     // FIXME(cir): create separate unions, struct, and classes types.
     // Convert struct members.
     llvm::SmallVector<mlir::Type> mlirMembers;
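
Applied to the test's struct s, only the array-typed member takes the value-semantics path; every other member keeps the default conversion:

// struct s { int a; double b; char c; float d[5]; };
// members now convert to [i32, f64, i8, tensor<5xf32>]
// (previously [i32, f64, i8, memref<5xf32>])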
@@ -1534,13 +1551,13 @@
     // TODO(cir): This should be properly validated.
     case cir::StructType::Struct:
       for (auto ty : type.getMembers())
-        mlirMembers.push_back(converter.convertType(ty));
+        mlirMembers.push_back(convertWithValueSemanticsArray(ty));
       break;
     // Unions are lowered as only the largest member.
     case cir::StructType::Union: {
       auto largestMember = type.getLargestMember(dataLayout);
       if (largestMember)
-        mlirMembers.push_back(converter.convertType(largestMember));
+        mlirMembers.push_back(convertWithValueSemanticsArray(largestMember));
       break;
     }
     }

clang/test/CIR/Lowering/ThroughMLIR/struct.cpp

Lines changed: 8 additions & 8 deletions
@@ -11,40 +11,40 @@ struct s {
 
 int main() {
   s v;
-  // CHECK: %[[ALLOCA:.+]] = memref.alloca() {alignment = 8 : i64} : memref<!named_tuple.named_tuple<"s", [i32, f64, i8, memref<5xf32>]>>
+  // CHECK: %[[ALLOCA:.+]] = memref.alloca() {alignment = 8 : i64} : memref<!named_tuple.named_tuple<"s", [i32, f64, i8, tensor<5xf32>]>>
   v.a = 7;
   // CHECK: %[[C_7:.+]] = arith.constant 7 : i32
-  // CHECK: %[[I8_EQUIV_A:.+]] = named_tuple.cast %[[ALLOCA]] : memref<!named_tuple.named_tuple<"s", [i32, f64, i8, memref<5xf32>]>> to memref<40xi8>
+  // CHECK: %[[I8_EQUIV_A:.+]] = named_tuple.cast %[[ALLOCA]] : memref<!named_tuple.named_tuple<"s", [i32, f64, i8, tensor<5xf32>]>> to memref<40xi8>
   // CHECK: %[[OFFSET_A:.+]] = arith.constant 0 : index
   // CHECK: %[[VIEW_A:.+]] = memref.view %[[I8_EQUIV_A]][%[[OFFSET_A]]][] : memref<40xi8> to memref<i32>
   // CHECK: memref.store %[[C_7]], %[[VIEW_A]][] : memref<i32>
 
   v.b = 3.;
   // CHECK: %[[C_3:.+]] = arith.constant 3.000000e+00 : f64
-  // CHECK: %[[I8_EQUIV_B:.+]] = named_tuple.cast %[[ALLOCA]] : memref<!named_tuple.named_tuple<"s", [i32, f64, i8, memref<5xf32>]>> to memref<40xi8>
+  // CHECK: %[[I8_EQUIV_B:.+]] = named_tuple.cast %[[ALLOCA]] : memref<!named_tuple.named_tuple<"s", [i32, f64, i8, tensor<5xf32>]>> to memref<40xi8>
   // CHECK: %[[OFFSET_B:.+]] = arith.constant 8 : index
   // CHECK: %[[VIEW_B:.+]] = memref.view %[[I8_EQUIV_B]][%[[OFFSET_B]]][] : memref<40xi8> to memref<f64>
   // CHECK: memref.store %[[C_3]], %[[VIEW_B]][] : memref<f64>
 
   v.c = 'z';
   // CHECK: %[[C_122:.+]] = arith.constant 122 : i8
-  // CHECK: %[[I8_EQUIV_C:.+]] = named_tuple.cast %[[ALLOCA]] : memref<!named_tuple.named_tuple<"s", [i32, f64, i8, memref<5xf32>]>> to memref<40xi8>
+  // CHECK: %[[I8_EQUIV_C:.+]] = named_tuple.cast %[[ALLOCA]] : memref<!named_tuple.named_tuple<"s", [i32, f64, i8, tensor<5xf32>]>> to memref<40xi8>
   // CHECK: %[[OFFSET_C:.+]] = arith.constant 16 : index
   // CHECK: %[[VIEW_C:.+]] = memref.view %[[I8_EQUIV_C]][%[[OFFSET_C]]][] : memref<40xi8> to memref<i8>
   // memref.store %[[C_122]], %[[VIEW_C]][] : memref<i8>
 
+  auto& a = v.d;
   v.d[4] = 6.f;
   // CHECK: %[[C_6:.+]] = arith.constant 6.000000e+00 : f32
-  // CHECK: %[[I8_EQUIV_D:.+]] = named_tuple.cast %alloca_0 : memref<!named_tuple.named_tuple<"s", [i32, f64, i8, memref<5xf32>]>> to memref<40xi8>
+  // CHECK: %[[I8_EQUIV_D:.+]] = named_tuple.cast %[[ALLOCA]] : memref<!named_tuple.named_tuple<"s", [i32, f64, i8, tensor<5xf32>]>> to memref<40xi8>
   // CHECK: %[[OFFSET_D:.+]] = arith.constant 20 : index
-  // Do not lower to a memref of memref
-  // CHECK: %[[VIEW_D:.+]] = memref.view %3[%c20][] : memref<40xi8> to memref<5xf32>
+  // CHECK: %[[VIEW_D:.+]] = memref.view %[[I8_EQUIV_D]][%[[OFFSET_D]]][] : memref<40xi8> to memref<5xf32>
   // CHECK: %[[C_4:.+]] = arith.constant 4 : i32
   // CHECK: %[[I_D:.+]] = arith.index_cast %[[C_4]] : i32 to index
   // CHECK: memref.store %[[C_6]], %[[VIEW_D]][%[[I_D]]] : memref<5xf32>
 
   return v.c;
-  // CHECK: %[[I8_EQUIV_C_1:.+]] = named_tuple.cast %[[ALLOCA]] : memref<!named_tuple.named_tuple<"s", [i32, f64, i8, memref<5xf32>]>> to memref<40xi8>
+  // CHECK: %[[I8_EQUIV_C_1:.+]] = named_tuple.cast %[[ALLOCA]] : memref<!named_tuple.named_tuple<"s", [i32, f64, i8, tensor<5xf32>]>> to memref<40xi8>
   // CHECK: %[[OFFSET_C_1:.+]] = arith.constant 16 : index
   // CHECK: %[[VIEW_C_1:.+]] = memref.view %[[I8_EQUIV_C_1]][%[[OFFSET_C_1]]][] : memref<40xi8> to memref<i8>
   // CHECK: %[[VALUE_C:.+]] = memref.load %[[VIEW_C_1]][] : memref<i8>