Skip to content

Commit 08be30d

Browse files
committed
[CIR][Lowering][MLIR] Rework the !cir.array lowering
Split completely the !cir.array lowering, like in struct/class/union, from any reference with memref construction. Rationalize the approach inside convertToReferenceType() instead of ad-hoc cases all-over the place. Fix a test which seems to have been wrong from the beginning.
1 parent ab35df7 commit 08be30d

File tree

1 file changed

+58
-64
lines changed

1 file changed

+58
-64
lines changed

clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp

Lines changed: 58 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,55 @@ using namespace llvm;
7373

7474
namespace cir {
7575

76+
/// Given a type convertor and a data layout, convert the given type to a type
77+
/// that is suitable for memory operations. For example, this can be used to
78+
/// lower cir.bool accesses to i8.
79+
static mlir::Type convertTypeForMemory(const mlir::TypeConverter &converter,
80+
mlir::Type type) {
81+
// TODO(cir): Handle other types similarly to clang's codegen
82+
// convertTypeForMemory
83+
if (isa<cir::BoolType>(type)) {
84+
// TODO: Use datalayout to get the size of bool
85+
return mlir::IntegerType::get(type.getContext(), 8);
86+
}
87+
88+
return converter.convertType(type);
89+
}
90+
91+
// Create a reference to an MLIR type. This creates a memref of the element type
92+
// with the requested shape except when it is a tensor because it represents a
93+
// !cir.array which has to be blessed as a memref of the tensor element type
94+
// instead.
95+
static mlir::MemRefType convertToReferenceType(ArrayRef<int64_t> shape,
96+
mlir::Type elementType) {
97+
if (auto t = mlir::dyn_cast<mlir::TensorType>(elementType))
98+
return mlir::MemRefType::get(t.getShape(), t.getElementType());
99+
return mlir::MemRefType::get(shape, elementType);
100+
}
101+
102+
// Lower a cir.array either as a memref when it has a reference semantics or as
103+
// a tensor when it has a value semantics (like inside a struct or union).
104+
mlir::Type lowerArrayType(cir::ArrayType type, bool hasValueSemantics,
105+
mlir::TypeConverter &converter) {
106+
SmallVector<int64_t> shape;
107+
mlir::Type curType = type;
108+
while (auto arrayType = dyn_cast<cir::ArrayType>(curType)) {
109+
shape.push_back(arrayType.getSize());
110+
curType = arrayType.getEltType();
111+
}
112+
auto elementType = convertTypeForMemory(converter, curType);
113+
// FIXME: The element type might not be converted.
114+
if (!elementType)
115+
return nullptr;
116+
// Arrays in C/C++ have a value semantics when in a struct, so use
117+
// a tensor.
118+
// TODO: tensors cannot contain most built-in types like memref.
119+
if (hasValueSemantics)
120+
return mlir::RankedTensorType::get(shape, elementType);
121+
// Otherwise, go to a memref.
122+
return convertToReferenceType(shape, elementType);
123+
}
124+
76125
class CIRReturnLowering : public mlir::OpConversionPattern<cir::ReturnOp> {
77126
public:
78127
using OpConversionPattern<cir::ReturnOp>::OpConversionPattern;
@@ -125,21 +174,6 @@ class CIRCallOpLowering : public mlir::OpConversionPattern<cir::CallOp> {
125174
}
126175
};
127176

128-
/// Given a type convertor and a data layout, convert the given type to a type
129-
/// that is suitable for memory operations. For example, this can be used to
130-
/// lower cir.bool accesses to i8.
131-
static mlir::Type convertTypeForMemory(const mlir::TypeConverter &converter,
132-
mlir::Type type) {
133-
// TODO(cir): Handle other types similarly to clang's codegen
134-
// convertTypeForMemory
135-
if (isa<cir::BoolType>(type)) {
136-
// TODO: Use datalayout to get the size of bool
137-
return mlir::IntegerType::get(type.getContext(), 8);
138-
}
139-
140-
return converter.convertType(type);
141-
}
142-
143177
/// Emits the value from memory as expected by its users. Should be called when
144178
/// the memory represetnation of a CIR type is not equal to its scalar
145179
/// representation.
@@ -188,14 +222,7 @@ class CIRAllocaOpLowering : public mlir::OpConversionPattern<cir::AllocaOp> {
188222
if (!mlirType)
189223
return mlir::LogicalResult::failure();
190224

191-
auto memreftype = mlir::dyn_cast<mlir::MemRefType>(mlirType);
192-
if (memreftype && mlir::isa<cir::ArrayType>(adaptor.getAllocaType())) {
193-
// if the type is an array,
194-
// we don't need to wrap with memref.
195-
} else {
196-
memreftype = mlir::MemRefType::get({}, mlirType);
197-
}
198-
225+
auto memreftype = convertToReferenceType({}, mlirType);
199226
rewriter.replaceOpWithNewOp<mlir::memref::AllocaOp>(op, memreftype,
200227
op.getAlignmentAttr());
201228
return mlir::LogicalResult::success();
@@ -337,7 +364,7 @@ class CIRGetMemberOpLowering
337364
auto flattenedMemRef = mlir::MemRefType::get(
338365
linearizedSize, mlir::IntegerType::get(getContext(), 8));
339366
// Use a special cast because normal memref cast cannot do such an extreme
340-
// cast.
367+
// cast. Could be an UnrealizedCastOp instead?
341368
auto bytesMemRef = rewriter.create<mlir::named_tuple::CastOp>(
342369
op.getLoc(), mlir::TypeRange{flattenedMemRef},
343370
mlir::ValueRange{adaptor.getAddr()});
@@ -1061,7 +1088,7 @@ class CIRGlobalOpLowering : public mlir::OpConversionPattern<cir::GlobalOp> {
10611088
return mlir::failure();
10621089
auto memrefType = dyn_cast<mlir::MemRefType>(convertedType);
10631090
if (!memrefType)
1064-
memrefType = mlir::MemRefType::get({}, convertedType);
1091+
memrefType = convertToReferenceType({}, convertedType);
10651092
// Add an optional alignment to the global memref.
10661093
mlir::IntegerAttr memrefAlignment =
10671094
op.getAlignment()
@@ -1474,27 +1501,6 @@ void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
14741501
cirDataLayout);
14751502
}
14761503

1477-
// Lower a cir.array either as a memref when it has a reference semantics or as
1478-
// a tensor when it has a value semantics (like inside a struct or union)
1479-
mlir::Type lowerArrayType(cir::ArrayType type, bool hasValueSemantics,
1480-
mlir::TypeConverter &converter) {
1481-
SmallVector<int64_t> shape;
1482-
mlir::Type curType = type;
1483-
while (auto arrayType = dyn_cast<cir::ArrayType>(curType)) {
1484-
shape.push_back(arrayType.getSize());
1485-
curType = arrayType.getEltType();
1486-
}
1487-
auto elementType = convertTypeForMemory(converter, curType);
1488-
// FIXME: The element type might not be converted
1489-
if (!elementType)
1490-
return nullptr;
1491-
// Arrays in C/C++ have a reference semantics when not in a struct, so use
1492-
// a memref
1493-
if (hasValueSemantics)
1494-
return mlir::RankedTensorType::get(shape, elementType);
1495-
return mlir::MemRefType::get(shape, elementType);
1496-
}
1497-
14981504
mlir::TypeConverter prepareTypeConverter(mlir::DataLayout &dataLayout) {
14991505
mlir::TypeConverter converter;
15001506
converter.addConversion([&](cir::PointerType type) -> mlir::Type {
@@ -1503,12 +1509,8 @@ mlir::TypeConverter prepareTypeConverter(mlir::DataLayout &dataLayout) {
15031509
// converted (e.g. struct)
15041510
if (!ty)
15051511
return nullptr;
1506-
if (isa<cir::ArrayType>(type.getPointee()))
1507-
// An array is already lowered as a memref with reference semantics by
1508-
// default
1509-
return ty;
15101512
// Each level of pointer becomes a level of memref
1511-
return mlir::MemRefType::get({}, ty);
1513+
return convertToReferenceType({}, ty);
15121514
});
15131515
converter.addConversion(
15141516
[&](mlir::IntegerType type) -> mlir::Type { return type; });
@@ -1546,23 +1548,15 @@ mlir::TypeConverter prepareTypeConverter(mlir::DataLayout &dataLayout) {
15461548
return mlir::BFloat16Type::get(type.getContext());
15471549
});
15481550
converter.addConversion([&](cir::ArrayType type) -> mlir::Type {
1549-
// Assume we are not in a class/struct/union context with value semantics,
1550-
// so lower it as a memref to provide reference semantics.
1551-
return lowerArrayType(type, /* hasValueSemantics */ false, converter);
1551+
// Assume we are in a class/struct/union context with value semantics,
1552+
// so lower it as a tensor to provide value semantics.
1553+
return lowerArrayType(type, /* hasValueSemantics */ true, converter);
15521554
});
15531555
converter.addConversion([&](cir::VectorType type) -> mlir::Type {
15541556
auto ty = converter.convertType(type.getEltType());
15551557
return mlir::VectorType::get(type.getSize(), ty);
15561558
});
15571559
converter.addConversion([&](cir::StructType type) -> mlir::Type {
1558-
auto convertWithValueSemanticsArray = [&](mlir::Type t) {
1559-
if (mlir::isa<cir::ArrayType>(t))
1560-
// Inside a class/struct/union, an array has value semantics and is
1561-
// lowered as a tensor.
1562-
return lowerArrayType(mlir::cast<cir::ArrayType>(t),
1563-
/* hasValueSemantics */ true, converter);
1564-
return converter.convertType(t);
1565-
};
15661560
// FIXME(cir): create separate unions, struct, and classes types.
15671561
// Convert struct members.
15681562
llvm::SmallVector<mlir::Type> mlirMembers;
@@ -1571,13 +1565,13 @@ mlir::TypeConverter prepareTypeConverter(mlir::DataLayout &dataLayout) {
15711565
// TODO(cir): This should be properly validated.
15721566
case cir::StructType::Struct:
15731567
for (auto ty : type.getMembers())
1574-
mlirMembers.push_back(convertWithValueSemanticsArray(ty));
1568+
mlirMembers.push_back(converter.convertType(ty));
15751569
break;
15761570
// Unions are lowered as only the largest member.
15771571
case cir::StructType::Union: {
15781572
auto largestMember = type.getLargestMember(dataLayout);
15791573
if (largestMember)
1580-
mlirMembers.push_back(convertWithValueSemanticsArray(largestMember));
1574+
mlirMembers.push_back(converter.convertType(largestMember));
15811575
break;
15821576
}
15831577
}

0 commit comments

Comments
 (0)