Skip to content

Commit a207932

Browse files
committed
[CIR][Lowering][MLIR] Rework the !cir.array lowering
Split completely the !cir.array lowering, like in struct/class/union, from any reference with memref construction. Rationalize the approach inside convertToReferenceType() instead of ad-hoc cases all-over the place. Fix a test which seems to have been wrong from the beginning.
1 parent 6c35a8f commit a207932

File tree

1 file changed

+58
-64
lines changed

1 file changed

+58
-64
lines changed

clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp

Lines changed: 58 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,55 @@ using namespace llvm;
6969

7070
namespace cir {
7171

72+
/// Given a type convertor and a data layout, convert the given type to a type
73+
/// that is suitable for memory operations. For example, this can be used to
74+
/// lower cir.bool accesses to i8.
75+
static mlir::Type convertTypeForMemory(const mlir::TypeConverter &converter,
76+
mlir::Type type) {
77+
// TODO(cir): Handle other types similarly to clang's codegen
78+
// convertTypeForMemory
79+
if (isa<cir::BoolType>(type)) {
80+
// TODO: Use datalayout to get the size of bool
81+
return mlir::IntegerType::get(type.getContext(), 8);
82+
}
83+
84+
return converter.convertType(type);
85+
}
86+
87+
// Create a reference to an MLIR type. This creates a memref of the element type
88+
// with the requested shape except when it is a tensor because it represents a
89+
// !cir.array which has to be blessed as a memref of the tensor element type
90+
// instead.
91+
static mlir::MemRefType convertToReferenceType(ArrayRef<int64_t> shape,
92+
mlir::Type elementType) {
93+
if (auto t = mlir::dyn_cast<mlir::TensorType>(elementType))
94+
return mlir::MemRefType::get(t.getShape(), t.getElementType());
95+
return mlir::MemRefType::get(shape, elementType);
96+
}
97+
98+
// Lower a cir.array either as a memref when it has a reference semantics or as
99+
// a tensor when it has a value semantics (like inside a struct or union).
100+
mlir::Type lowerArrayType(cir::ArrayType type, bool hasValueSemantics,
101+
mlir::TypeConverter &converter) {
102+
SmallVector<int64_t> shape;
103+
mlir::Type curType = type;
104+
while (auto arrayType = dyn_cast<cir::ArrayType>(curType)) {
105+
shape.push_back(arrayType.getSize());
106+
curType = arrayType.getEltType();
107+
}
108+
auto elementType = convertTypeForMemory(converter, curType);
109+
// FIXME: The element type might not be converted.
110+
if (!elementType)
111+
return nullptr;
112+
// Arrays in C/C++ have a value semantics when in a struct, so use
113+
// a tensor.
114+
// TODO: tensors cannot contain most built-in types like memref.
115+
if (hasValueSemantics)
116+
return mlir::RankedTensorType::get(shape, elementType);
117+
// Otherwise, go to a memref.
118+
return convertToReferenceType(shape, elementType);
119+
}
120+
72121
class CIRReturnLowering : public mlir::OpConversionPattern<cir::ReturnOp> {
73122
public:
74123
using OpConversionPattern<cir::ReturnOp>::OpConversionPattern;
@@ -121,21 +170,6 @@ class CIRCallOpLowering : public mlir::OpConversionPattern<cir::CallOp> {
121170
}
122171
};
123172

124-
/// Given a type convertor and a data layout, convert the given type to a type
125-
/// that is suitable for memory operations. For example, this can be used to
126-
/// lower cir.bool accesses to i8.
127-
static mlir::Type convertTypeForMemory(const mlir::TypeConverter &converter,
128-
mlir::Type type) {
129-
// TODO(cir): Handle other types similarly to clang's codegen
130-
// convertTypeForMemory
131-
if (isa<cir::BoolType>(type)) {
132-
// TODO: Use datalayout to get the size of bool
133-
return mlir::IntegerType::get(type.getContext(), 8);
134-
}
135-
136-
return converter.convertType(type);
137-
}
138-
139173
/// Emits the value from memory as expected by its users. Should be called when
140174
/// the memory represetnation of a CIR type is not equal to its scalar
141175
/// representation.
@@ -184,14 +218,7 @@ class CIRAllocaOpLowering : public mlir::OpConversionPattern<cir::AllocaOp> {
184218
if (!mlirType)
185219
return mlir::LogicalResult::failure();
186220

187-
auto memreftype = mlir::dyn_cast<mlir::MemRefType>(mlirType);
188-
if (memreftype && mlir::isa<cir::ArrayType>(adaptor.getAllocaType())) {
189-
// if the type is an array,
190-
// we don't need to wrap with memref.
191-
} else {
192-
memreftype = mlir::MemRefType::get({}, mlirType);
193-
}
194-
221+
auto memreftype = convertToReferenceType({}, mlirType);
195222
rewriter.replaceOpWithNewOp<mlir::memref::AllocaOp>(op, memreftype,
196223
op.getAlignmentAttr());
197224
return mlir::LogicalResult::success();
@@ -333,7 +360,7 @@ class CIRGetMemberOpLowering
333360
auto flattenedMemRef = mlir::MemRefType::get(
334361
linearizedSize, mlir::IntegerType::get(getContext(), 8));
335362
// Use a special cast because normal memref cast cannot do such an extreme
336-
// cast.
363+
// cast. Could be an UnrealizedCastOp instead?
337364
auto bytesMemRef = rewriter.create<mlir::named_tuple::CastOp>(
338365
op.getLoc(), mlir::TypeRange{flattenedMemRef},
339366
mlir::ValueRange{adaptor.getAddr()});
@@ -949,7 +976,7 @@ class CIRGlobalOpLowering : public mlir::OpConversionPattern<cir::GlobalOp> {
949976
return mlir::failure();
950977
auto memrefType = dyn_cast<mlir::MemRefType>(convertedType);
951978
if (!memrefType)
952-
memrefType = mlir::MemRefType::get({}, convertedType);
979+
memrefType = convertToReferenceType({}, convertedType);
953980
// Add an optional alignment to the global memref.
954981
mlir::IntegerAttr memrefAlignment =
955982
op.getAlignment()
@@ -1394,27 +1421,6 @@ void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
13941421
cirDataLayout);
13951422
}
13961423

1397-
// Lower a cir.array either as a memref when it has a reference semantics or as
1398-
// a tensor when it has a value semantics (like inside a struct or union)
1399-
mlir::Type lowerArrayType(cir::ArrayType type, bool hasValueSemantics,
1400-
mlir::TypeConverter &converter) {
1401-
SmallVector<int64_t> shape;
1402-
mlir::Type curType = type;
1403-
while (auto arrayType = dyn_cast<cir::ArrayType>(curType)) {
1404-
shape.push_back(arrayType.getSize());
1405-
curType = arrayType.getEltType();
1406-
}
1407-
auto elementType = convertTypeForMemory(converter, curType);
1408-
// FIXME: The element type might not be converted
1409-
if (!elementType)
1410-
return nullptr;
1411-
// Arrays in C/C++ have a reference semantics when not in a struct, so use
1412-
// a memref
1413-
if (hasValueSemantics)
1414-
return mlir::RankedTensorType::get(shape, elementType);
1415-
return mlir::MemRefType::get(shape, elementType);
1416-
}
1417-
14181424
mlir::TypeConverter prepareTypeConverter(mlir::DataLayout &dataLayout) {
14191425
mlir::TypeConverter converter;
14201426
converter.addConversion([&](cir::PointerType type) -> mlir::Type {
@@ -1423,12 +1429,8 @@ mlir::TypeConverter prepareTypeConverter(mlir::DataLayout &dataLayout) {
14231429
// converted (e.g. struct)
14241430
if (!ty)
14251431
return nullptr;
1426-
if (isa<cir::ArrayType>(type.getPointee()))
1427-
// An array is already lowered as a memref with reference semantics by
1428-
// default
1429-
return ty;
14301432
// Each level of pointer becomes a level of memref
1431-
return mlir::MemRefType::get({}, ty);
1433+
return convertToReferenceType({}, ty);
14321434
});
14331435
converter.addConversion(
14341436
[&](mlir::IntegerType type) -> mlir::Type { return type; });
@@ -1466,23 +1468,15 @@ mlir::TypeConverter prepareTypeConverter(mlir::DataLayout &dataLayout) {
14661468
return mlir::BFloat16Type::get(type.getContext());
14671469
});
14681470
converter.addConversion([&](cir::ArrayType type) -> mlir::Type {
1469-
// Assume we are not in a class/struct/union context with value semantics,
1470-
// so lower it as a memref to provide reference semantics.
1471-
return lowerArrayType(type, /* hasValueSemantics */ false, converter);
1471+
// Assume we are in a class/struct/union context with value semantics,
1472+
// so lower it as a tensor to provide value semantics.
1473+
return lowerArrayType(type, /* hasValueSemantics */ true, converter);
14721474
});
14731475
converter.addConversion([&](cir::VectorType type) -> mlir::Type {
14741476
auto ty = converter.convertType(type.getEltType());
14751477
return mlir::VectorType::get(type.getSize(), ty);
14761478
});
14771479
converter.addConversion([&](cir::StructType type) -> mlir::Type {
1478-
auto convertWithValueSemanticsArray = [&](mlir::Type t) {
1479-
if (mlir::isa<cir::ArrayType>(t))
1480-
// Inside a class/struct/union, an array has value semantics and is
1481-
// lowered as a tensor.
1482-
return lowerArrayType(mlir::cast<cir::ArrayType>(t),
1483-
/* hasValueSemantics */ true, converter);
1484-
return converter.convertType(t);
1485-
};
14861480
// FIXME(cir): create separate unions, struct, and classes types.
14871481
// Convert struct members.
14881482
llvm::SmallVector<mlir::Type> mlirMembers;
@@ -1491,13 +1485,13 @@ mlir::TypeConverter prepareTypeConverter(mlir::DataLayout &dataLayout) {
14911485
// TODO(cir): This should be properly validated.
14921486
case cir::StructType::Struct:
14931487
for (auto ty : type.getMembers())
1494-
mlirMembers.push_back(convertWithValueSemanticsArray(ty));
1488+
mlirMembers.push_back(converter.convertType(ty));
14951489
break;
14961490
// Unions are lowered as only the largest member.
14971491
case cir::StructType::Union: {
14981492
auto largestMember = type.getLargestMember(dataLayout);
14991493
if (largestMember)
1500-
mlirMembers.push_back(convertWithValueSemanticsArray(largestMember));
1494+
mlirMembers.push_back(converter.convertType(largestMember));
15011495
break;
15021496
}
15031497
}

0 commit comments

Comments
 (0)