
Commit a967bc7

[CIR][Lowering] Lower arrays in class/struct/union as tensor
Arrays in C/C++ usually have reference semantics and can be lowered to memref. But when inside a class/struct/union, arrays have value semantics and can be lowered to tensor.
1 parent 64a0abd commit a967bc7
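
To make the distinction concrete, here is a small C++ sketch based on the test in this commit (the function-scope array is an illustrative addition, not part of the test):

struct s {
  int a;      // lowered member type: i32
  double b;   // lowered member type: f64
  char c;     // lowered member type: i8
  float d[5]; // value semantics inside the struct: lowered as tensor<5xf32>
};

void f() {
  float local[5]; // reference semantics at function scope: lowered as memref<5xf32>
}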

2 files changed: +59 -42 lines

clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp

Lines changed: 51 additions & 34 deletions
@@ -289,13 +289,13 @@ class CIRStoreOpLowering : public mlir::OpConversionPattern<cir::StoreOp> {
   }
 };
 
-// Lower cir.get_member
+// Lower cir.get_member by aliasing the result memref to the member inside the
+// flattened structure as a byte array. For example
 //
 // clang-format off
-//
 // %5 = cir.get_member %1[1] {name = "b"} : !cir.ptr<!named_tuple.named_tuple<"s", [i32, f64, i8]>> -> !cir.ptr<!cir.double>
 //
-// to something like
+// is lowered to something like
 //
 // %1 = named_tuple.cast %alloca_0 : memref<!named_tuple.named_tuple<"s", [i32, f64, i8]>> to memref<24xi8>
 // %c8 = arith.constant 8 : index
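The example in this comment is cut off at the hunk boundary; a hedged sketch of how it presumably continues, matching the pattern code below (the %2 name is illustrative):

// %2 = memref.view %1[%c8][] : memref<24xi8> to memref<f64>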
@@ -325,37 +325,30 @@ class CIRGetMemberOpLowering
     // concrete datalayout, both datalayouts are the same.
     auto *structLayout = dataLayout.getStructLayout(structType);
 
-    // Get the lowered type: memref<!named_tuple.named_tuple<>>
-    auto memref = mlir::cast<mlir::MemRefType>(adaptor.getAddr().getType());
     // Alias the memref of struct to a memref of an i8 array of the same size.
     const std::array linearizedSize{
         static_cast<std::int64_t>(dataLayout.getTypeStoreSize(structType))};
-    auto flattenMemRef = mlir::MemRefType::get(
-        linearizedSize, mlir::IntegerType::get(memref.getContext(), 8));
+    auto flattenedMemRef = mlir::MemRefType::get(
+        linearizedSize, mlir::IntegerType::get(getContext(), 8));
     // Use a special cast because normal memref cast cannot do such an extreme
     // cast.
     auto bytesMemRef = rewriter.create<mlir::named_tuple::CastOp>(
-        op.getLoc(), mlir::TypeRange{flattenMemRef},
+        op.getLoc(), mlir::TypeRange{flattenedMemRef},
         mlir::ValueRange{adaptor.getAddr()});
 
+    auto pointerToMemberTypeToLower = op.getResultTy();
+    // The lowered type of the cir.ptr to the cir.struct member.
+    auto memrefToLoweredMemberType =
+        typeConverter->convertType(pointerToMemberTypeToLower);
+    // Synthesize the byte access to the right lowered type.
     auto memberIndex = op.getIndex();
-    auto namedTupleType =
-        mlir::cast<mlir::named_tuple::NamedTupleType>(memref.getElementType());
-    // The lowered type of the element to access in the named_tuple.
-    auto loweredMemberType = namedTupleType.getType(memberIndex);
-    // memref.view can only cast to another memref. Wrap the target type if it
-    // is not already a memref (like with a struct with an array member)
-    mlir::MemRefType elementMemRefTy;
-    if (mlir::isa<mlir::MemRefType>(loweredMemberType))
-      elementMemRefTy = mlir::cast<mlir::MemRefType>(loweredMemberType);
-    else
-      elementMemRefTy = mlir::MemRefType::get({}, loweredMemberType);
     auto offset = structLayout->getElementOffset(memberIndex);
-    // Synthesize the byte access to right lowered type.
     auto byteShift =
         rewriter.create<mlir::arith::ConstantIndexOp>(op.getLoc(), offset);
+    // Create the memref pointing to the flattened member location.
     rewriter.replaceOpWithNewOp<mlir::memref::ViewOp>(
-        op, elementMemRefTy, bytesMemRef, byteShift, mlir::ValueRange{});
+        op, memrefToLoweredMemberType, bytesMemRef, byteShift,
+        mlir::ValueRange{});
     return mlir::LogicalResult::success();
   }
 };
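The net effect of this hunk: rather than unwrapping the named_tuple element type and special-casing non-memref members, the pattern converts the cir.ptr result type directly, and the pointer conversion below already produces a memref, even for array members. A hedged sketch of the type flow (the member types are illustrative):

// op.getResultTy():    !cir.ptr<!cir.double>                 -> memref<f64>
// for an array member: !cir.ptr<!cir.array<!cir.float x 5>>  -> memref<5xf32>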
@@ -1382,6 +1375,29 @@ void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
       cirDataLayout);
 }
 
+namespace {
+// Lower a cir.array either as a memref when it has reference semantics or as
+// a tensor when it has value semantics (like inside a struct or union).
+mlir::Type lowerArrayType(cir::ArrayType type, bool hasValueSemantics,
+                          mlir::TypeConverter &converter) {
+  SmallVector<int64_t> shape;
+  mlir::Type curType = type;
+  while (auto arrayType = dyn_cast<cir::ArrayType>(curType)) {
+    shape.push_back(arrayType.getSize());
+    curType = arrayType.getEltType();
+  }
+  auto elementType = converter.convertType(curType);
+  // FIXME: The element type might not be converted
+  if (!elementType)
+    return nullptr;
+  // Inside a class/struct/union an array has value semantics and is lowered
+  // as a tensor; otherwise it has reference semantics and becomes a memref.
+  if (hasValueSemantics)
+    return mlir::RankedTensorType::get(shape, elementType);
+  return mlir::MemRefType::get(shape, elementType);
+}
+} // namespace
+
 mlir::TypeConverter prepareTypeConverter(mlir::DataLayout &dataLayout) {
   mlir::TypeConverter converter;
   converter.addConversion([&](cir::PointerType type) -> mlir::Type {
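A hedged sketch of what lowerArrayType produces (the helper and flag are from this diff; the concrete CIR types are illustrative). Nested cir.array types are flattened into a single shape, outer dimension first, before choosing between tensor and memref:

// !cir.array<!cir.float x 5>             with hasValueSemantics == true
//   -> tensor<5xf32>
// !cir.array<!cir.array<!s32i x 3> x 2>  with hasValueSemantics == false
//   -> memref<2x3xi32>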
@@ -1390,6 +1406,7 @@ mlir::TypeConverter prepareTypeConverter(mlir::DataLayout &dataLayout) {
     if (!ty)
       return nullptr;
     if (isa<cir::ArrayType>(type.getPointee()))
+      // An array is already lowered as a memref with reference semantics.
       return ty;
     return mlir::MemRefType::get({}, ty);
   });
@@ -1429,23 +1446,23 @@ mlir::TypeConverter prepareTypeConverter(mlir::DataLayout &dataLayout) {
       return mlir::BFloat16Type::get(type.getContext());
   });
   converter.addConversion([&](cir::ArrayType type) -> mlir::Type {
-    SmallVector<int64_t> shape;
-    mlir::Type curType = type;
-    while (auto arrayType = dyn_cast<cir::ArrayType>(curType)) {
-      shape.push_back(arrayType.getSize());
-      curType = arrayType.getEltType();
-    }
-    auto elementType = converter.convertType(curType);
-    // FIXME: The element type might not be converted
-    if (!elementType)
-      return nullptr;
-    return mlir::MemRefType::get(shape, elementType);
+    // Arrays in C/C++ have reference semantics when not in a
+    // class/struct/union, so use a memref.
+    return lowerArrayType(type, /* hasValueSemantics */ false, converter);
   });
   converter.addConversion([&](cir::VectorType type) -> mlir::Type {
     auto ty = converter.convertType(type.getEltType());
     return mlir::VectorType::get(type.getSize(), ty);
   });
   converter.addConversion([&](cir::StructType type) -> mlir::Type {
+    auto convertWithValueSemanticsArray = [&](mlir::Type t) {
+      if (mlir::isa<cir::ArrayType>(t))
+        // Inside a class/struct/union, an array has value semantics and is
+        // lowered as a tensor.
+        return lowerArrayType(mlir::cast<cir::ArrayType>(t),
+                              /* hasValueSemantics */ true, converter);
+      return converter.convertType(t);
+    };
     // FIXME(cir): create separate unions, struct, and classes types.
     // Convert struct members.
     llvm::SmallVector<mlir::Type> mlirMembers;
@@ -1454,13 +1471,13 @@ mlir::TypeConverter prepareTypeConverter(mlir::DataLayout &dataLayout) {
     // TODO(cir): This should be properly validated.
     case cir::StructType::Struct:
       for (auto ty : type.getMembers())
-        mlirMembers.push_back(converter.convertType(ty));
+        mlirMembers.push_back(convertWithValueSemanticsArray(ty));
       break;
     // Unions are lowered as only the largest member.
     case cir::StructType::Union: {
       auto largestMember = type.getLargestMember(dataLayout);
       if (largestMember)
-        mlirMembers.push_back(converter.convertType(largestMember));
+        mlirMembers.push_back(convertWithValueSemanticsArray(largestMember));
       break;
     }
   }
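
Combining the array and struct conversions, a hedged sketch of the lowering for the struct exercised by the test below:

// struct s { int a; double b; char c; float d[5]; };
//   -> !named_tuple.named_tuple<"s", [i32, f64, i8, tensor<5xf32>]>
// The same float d[5] outside a struct would still become memref<5xf32>.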

clang/test/CIR/Lowering/ThroughMLIR/struct.cpp

Lines changed: 8 additions & 8 deletions
@@ -11,40 +11,40 @@ struct s {
 
 int main() {
   s v;
-  // CHECK: %[[ALLOCA:.+]] = memref.alloca() {alignment = 8 : i64} : memref<!named_tuple.named_tuple<"s", [i32, f64, i8, memref<5xf32>]>>
+  // CHECK: %[[ALLOCA:.+]] = memref.alloca() {alignment = 8 : i64} : memref<!named_tuple.named_tuple<"s", [i32, f64, i8, tensor<5xf32>]>>
   v.a = 7;
   // CHECK: %[[C_7:.+]] = arith.constant 7 : i32
-  // CHECK: %[[I8_EQUIV_A:.+]] = named_tuple.cast %[[ALLOCA]] : memref<!named_tuple.named_tuple<"s", [i32, f64, i8, memref<5xf32>]>> to memref<40xi8>
+  // CHECK: %[[I8_EQUIV_A:.+]] = named_tuple.cast %[[ALLOCA]] : memref<!named_tuple.named_tuple<"s", [i32, f64, i8, tensor<5xf32>]>> to memref<40xi8>
   // CHECK: %[[OFFSET_A:.+]] = arith.constant 0 : index
   // CHECK: %[[VIEW_A:.+]] = memref.view %[[I8_EQUIV_A]][%[[OFFSET_A]]][] : memref<40xi8> to memref<i32>
   // CHECK: memref.store %[[C_7]], %[[VIEW_A]][] : memref<i32>
 
   v.b = 3.;
   // CHECK: %[[C_3:.+]] = arith.constant 3.000000e+00 : f64
-  // CHECK: %[[I8_EQUIV_B:.+]] = named_tuple.cast %[[ALLOCA]] : memref<!named_tuple.named_tuple<"s", [i32, f64, i8, memref<5xf32>]>> to memref<40xi8>
+  // CHECK: %[[I8_EQUIV_B:.+]] = named_tuple.cast %[[ALLOCA]] : memref<!named_tuple.named_tuple<"s", [i32, f64, i8, tensor<5xf32>]>> to memref<40xi8>
   // CHECK: %[[OFFSET_B:.+]] = arith.constant 8 : index
   // CHECK: %[[VIEW_B:.+]] = memref.view %[[I8_EQUIV_B]][%[[OFFSET_B]]][] : memref<40xi8> to memref<f64>
   // CHECK: memref.store %[[C_3]], %[[VIEW_B]][] : memref<f64>
 
   v.c = 'z';
   // CHECK: %[[C_122:.+]] = arith.constant 122 : i8
-  // CHECK: %[[I8_EQUIV_C:.+]] = named_tuple.cast %[[ALLOCA]] : memref<!named_tuple.named_tuple<"s", [i32, f64, i8, memref<5xf32>]>> to memref<40xi8>
+  // CHECK: %[[I8_EQUIV_C:.+]] = named_tuple.cast %[[ALLOCA]] : memref<!named_tuple.named_tuple<"s", [i32, f64, i8, tensor<5xf32>]>> to memref<40xi8>
   // CHECK: %[[OFFSET_C:.+]] = arith.constant 16 : index
   // CHECK: %[[VIEW_C:.+]] = memref.view %[[I8_EQUIV_C]][%[[OFFSET_C]]][] : memref<40xi8> to memref<i8>
   // memref.store %[[C_122]], %[[VIEW_C]][] : memref<i8>
 
+  auto& a = v.d;
   v.d[4] = 6.f;
   // CHECK: %[[C_6:.+]] = arith.constant 6.000000e+00 : f32
-  // CHECK: %[[I8_EQUIV_D:.+]] = named_tuple.cast %alloca_0 : memref<!named_tuple.named_tuple<"s", [i32, f64, i8, memref<5xf32>]>> to memref<40xi8>
+  // CHECK: %[[I8_EQUIV_D:.+]] = named_tuple.cast %[[ALLOCA]] : memref<!named_tuple.named_tuple<"s", [i32, f64, i8, tensor<5xf32>]>> to memref<40xi8>
   // CHECK: %[[OFFSET_D:.+]] = arith.constant 20 : index
-  // Do not lower to a memref of memref
-  // CHECK: %[[VIEW_D:.+]] = memref.view %3[%c20][] : memref<40xi8> to memref<5xf32>
+  // CHECK: %[[VIEW_D:.+]] = memref.view %[[I8_EQUIV_D]][%[[OFFSET_D]]][] : memref<40xi8> to memref<5xf32>
   // CHECK: %[[C_4:.+]] = arith.constant 4 : i32
   // CHECK: %[[I_D:.+]] = arith.index_cast %[[C_4]] : i32 to index
   // CHECK: memref.store %[[C_6]], %[[VIEW_D]][%[[I_D]]] : memref<5xf32>
 
   return v.c;
-  // CHECK: %[[I8_EQUIV_C_1:.+]] = named_tuple.cast %[[ALLOCA]] : memref<!named_tuple.named_tuple<"s", [i32, f64, i8, memref<5xf32>]>> to memref<40xi8>
+  // CHECK: %[[I8_EQUIV_C_1:.+]] = named_tuple.cast %[[ALLOCA]] : memref<!named_tuple.named_tuple<"s", [i32, f64, i8, tensor<5xf32>]>> to memref<40xi8>
   // CHECK: %[[OFFSET_C_1:.+]] = arith.constant 16 : index
   // CHECK: %[[VIEW_C_1:.+]] = memref.view %[[I8_EQUIV_C_1]][%[[OFFSET_C_1]]][] : memref<40xi8> to memref<i8>
   // CHECK: %[[VALUE_C:.+]] = memref.load %[[VIEW_C_1]][] : memref<i8>
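
The constants in these checks follow from the layout of struct s under a typical 64-bit datalayout (an assumption; the test does not spell this out):

// a: i32      at offset  0
// b: f64      at offset  8 (aligned to 8)
// c: i8       at offset 16
// d: float[5] at offset 20, ending at byte 40
// store size 40 -> the flattened memref<40xi8>; member views use offsets 0/8/16/20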
