Skip to content

Commit af16fa3

Browse files
committed
[CIR][Lowering][MLIR] Rework the !cir.array lowering
Split completely the !cir.array lowering, like in struct/class/union, from any reference with memref construction. Rationalize the approach inside convertToReferenceType() instead of ad-hoc cases all-over the place. Fix a test which seems to have been wrong from the beginning.
1 parent 0a66217 commit af16fa3

File tree

1 file changed

+58
-64
lines changed

1 file changed

+58
-64
lines changed

clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp

Lines changed: 58 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,55 @@ using namespace llvm;
7272

7373
namespace cir {
7474

75+
/// Given a type convertor and a data layout, convert the given type to a type
76+
/// that is suitable for memory operations. For example, this can be used to
77+
/// lower cir.bool accesses to i8.
78+
static mlir::Type convertTypeForMemory(const mlir::TypeConverter &converter,
79+
mlir::Type type) {
80+
// TODO(cir): Handle other types similarly to clang's codegen
81+
// convertTypeForMemory
82+
if (isa<cir::BoolType>(type)) {
83+
// TODO: Use datalayout to get the size of bool
84+
return mlir::IntegerType::get(type.getContext(), 8);
85+
}
86+
87+
return converter.convertType(type);
88+
}
89+
90+
// Create a reference to an MLIR type. This creates a memref of the element type
91+
// with the requested shape except when it is a tensor because it represents a
92+
// !cir.array which has to be blessed as a memref of the tensor element type
93+
// instead.
94+
static mlir::MemRefType convertToReferenceType(ArrayRef<int64_t> shape,
95+
mlir::Type elementType) {
96+
if (auto t = mlir::dyn_cast<mlir::TensorType>(elementType))
97+
return mlir::MemRefType::get(t.getShape(), t.getElementType());
98+
return mlir::MemRefType::get(shape, elementType);
99+
}
100+
101+
// Lower a cir.array either as a memref when it has a reference semantics or as
102+
// a tensor when it has a value semantics (like inside a struct or union).
103+
mlir::Type lowerArrayType(cir::ArrayType type, bool hasValueSemantics,
104+
mlir::TypeConverter &converter) {
105+
SmallVector<int64_t> shape;
106+
mlir::Type curType = type;
107+
while (auto arrayType = dyn_cast<cir::ArrayType>(curType)) {
108+
shape.push_back(arrayType.getSize());
109+
curType = arrayType.getEltType();
110+
}
111+
auto elementType = convertTypeForMemory(converter, curType);
112+
// FIXME: The element type might not be converted.
113+
if (!elementType)
114+
return nullptr;
115+
// Arrays in C/C++ have a value semantics when in a struct, so use
116+
// a tensor.
117+
// TODO: tensors cannot contain most built-in types like memref.
118+
if (hasValueSemantics)
119+
return mlir::RankedTensorType::get(shape, elementType);
120+
// Otherwise, go to a memref.
121+
return convertToReferenceType(shape, elementType);
122+
}
123+
75124
class CIRReturnLowering : public mlir::OpConversionPattern<cir::ReturnOp> {
76125
public:
77126
using OpConversionPattern<cir::ReturnOp>::OpConversionPattern;
@@ -124,21 +173,6 @@ class CIRCallOpLowering : public mlir::OpConversionPattern<cir::CallOp> {
124173
}
125174
};
126175

127-
/// Given a type convertor and a data layout, convert the given type to a type
128-
/// that is suitable for memory operations. For example, this can be used to
129-
/// lower cir.bool accesses to i8.
130-
static mlir::Type convertTypeForMemory(const mlir::TypeConverter &converter,
131-
mlir::Type type) {
132-
// TODO(cir): Handle other types similarly to clang's codegen
133-
// convertTypeForMemory
134-
if (isa<cir::BoolType>(type)) {
135-
// TODO: Use datalayout to get the size of bool
136-
return mlir::IntegerType::get(type.getContext(), 8);
137-
}
138-
139-
return converter.convertType(type);
140-
}
141-
142176
/// Emits the value from memory as expected by its users. Should be called when
143177
/// the memory represetnation of a CIR type is not equal to its scalar
144178
/// representation.
@@ -187,14 +221,7 @@ class CIRAllocaOpLowering : public mlir::OpConversionPattern<cir::AllocaOp> {
187221
if (!mlirType)
188222
return mlir::LogicalResult::failure();
189223

190-
auto memreftype = mlir::dyn_cast<mlir::MemRefType>(mlirType);
191-
if (memreftype && mlir::isa<cir::ArrayType>(adaptor.getAllocaType())) {
192-
// if the type is an array,
193-
// we don't need to wrap with memref.
194-
} else {
195-
memreftype = mlir::MemRefType::get({}, mlirType);
196-
}
197-
224+
auto memreftype = convertToReferenceType({}, mlirType);
198225
rewriter.replaceOpWithNewOp<mlir::memref::AllocaOp>(op, memreftype,
199226
op.getAlignmentAttr());
200227
return mlir::LogicalResult::success();
@@ -336,7 +363,7 @@ class CIRGetMemberOpLowering
336363
auto flattenedMemRef = mlir::MemRefType::get(
337364
linearizedSize, mlir::IntegerType::get(getContext(), 8));
338365
// Use a special cast because normal memref cast cannot do such an extreme
339-
// cast.
366+
// cast. Could be an UnrealizedCastOp instead?
340367
auto bytesMemRef = rewriter.create<mlir::named_tuple::CastOp>(
341368
op.getLoc(), mlir::TypeRange{flattenedMemRef},
342369
mlir::ValueRange{adaptor.getAddr()});
@@ -1062,7 +1089,7 @@ class CIRGlobalOpLowering : public mlir::OpConversionPattern<cir::GlobalOp> {
10621089
return mlir::failure();
10631090
auto memrefType = dyn_cast<mlir::MemRefType>(convertedType);
10641091
if (!memrefType)
1065-
memrefType = mlir::MemRefType::get({}, convertedType);
1092+
memrefType = convertToReferenceType({}, convertedType);
10661093
// Add an optional alignment to the global memref.
10671094
mlir::IntegerAttr memrefAlignment =
10681095
op.getAlignment()
@@ -1475,27 +1502,6 @@ void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
14751502
cirDataLayout);
14761503
}
14771504

1478-
// Lower a cir.array either as a memref when it has a reference semantics or as
1479-
// a tensor when it has a value semantics (like inside a struct or union)
1480-
mlir::Type lowerArrayType(cir::ArrayType type, bool hasValueSemantics,
1481-
mlir::TypeConverter &converter) {
1482-
SmallVector<int64_t> shape;
1483-
mlir::Type curType = type;
1484-
while (auto arrayType = dyn_cast<cir::ArrayType>(curType)) {
1485-
shape.push_back(arrayType.getSize());
1486-
curType = arrayType.getEltType();
1487-
}
1488-
auto elementType = convertTypeForMemory(converter, curType);
1489-
// FIXME: The element type might not be converted
1490-
if (!elementType)
1491-
return nullptr;
1492-
// Arrays in C/C++ have a reference semantics when not in a struct, so use
1493-
// a memref
1494-
if (hasValueSemantics)
1495-
return mlir::RankedTensorType::get(shape, elementType);
1496-
return mlir::MemRefType::get(shape, elementType);
1497-
}
1498-
14991505
mlir::TypeConverter prepareTypeConverter(mlir::DataLayout &dataLayout) {
15001506
mlir::TypeConverter converter;
15011507
converter.addConversion([&](cir::PointerType type) -> mlir::Type {
@@ -1504,12 +1510,8 @@ mlir::TypeConverter prepareTypeConverter(mlir::DataLayout &dataLayout) {
15041510
// converted (e.g. struct)
15051511
if (!ty)
15061512
return nullptr;
1507-
if (isa<cir::ArrayType>(type.getPointee()))
1508-
// An array is already lowered as a memref with reference semantics by
1509-
// default
1510-
return ty;
15111513
// Each level of pointer becomes a level of memref
1512-
return mlir::MemRefType::get({}, ty);
1514+
return convertToReferenceType({}, ty);
15131515
});
15141516
converter.addConversion(
15151517
[&](mlir::IntegerType type) -> mlir::Type { return type; });
@@ -1547,23 +1549,15 @@ mlir::TypeConverter prepareTypeConverter(mlir::DataLayout &dataLayout) {
15471549
return mlir::BFloat16Type::get(type.getContext());
15481550
});
15491551
converter.addConversion([&](cir::ArrayType type) -> mlir::Type {
1550-
// Assume we are not in a class/struct/union context with value semantics,
1551-
// so lower it as a memref to provide reference semantics.
1552-
return lowerArrayType(type, /* hasValueSemantics */ false, converter);
1552+
// Assume we are in a class/struct/union context with value semantics,
1553+
// so lower it as a tensor to provide value semantics.
1554+
return lowerArrayType(type, /* hasValueSemantics */ true, converter);
15531555
});
15541556
converter.addConversion([&](cir::VectorType type) -> mlir::Type {
15551557
auto ty = converter.convertType(type.getEltType());
15561558
return mlir::VectorType::get(type.getSize(), ty);
15571559
});
15581560
converter.addConversion([&](cir::StructType type) -> mlir::Type {
1559-
auto convertWithValueSemanticsArray = [&](mlir::Type t) {
1560-
if (mlir::isa<cir::ArrayType>(t))
1561-
// Inside a class/struct/union, an array has value semantics and is
1562-
// lowered as a tensor.
1563-
return lowerArrayType(mlir::cast<cir::ArrayType>(t),
1564-
/* hasValueSemantics */ true, converter);
1565-
return converter.convertType(t);
1566-
};
15671561
// FIXME(cir): create separate unions, struct, and classes types.
15681562
// Convert struct members.
15691563
llvm::SmallVector<mlir::Type> mlirMembers;
@@ -1572,13 +1566,13 @@ mlir::TypeConverter prepareTypeConverter(mlir::DataLayout &dataLayout) {
15721566
// TODO(cir): This should be properly validated.
15731567
case cir::StructType::Struct:
15741568
for (auto ty : type.getMembers())
1575-
mlirMembers.push_back(convertWithValueSemanticsArray(ty));
1569+
mlirMembers.push_back(converter.convertType(ty));
15761570
break;
15771571
// Unions are lowered as only the largest member.
15781572
case cir::StructType::Union: {
15791573
auto largestMember = type.getLargestMember(dataLayout);
15801574
if (largestMember)
1581-
mlirMembers.push_back(convertWithValueSemanticsArray(largestMember));
1575+
mlirMembers.push_back(converter.convertType(largestMember));
15821576
break;
15831577
}
15841578
}

0 commit comments

Comments
 (0)