@@ -31,6 +31,18 @@ class CIR_ConfinedType<Type type, list<Pred> preds, string summary = "">
3131 : Type<And<[type.predicate, CIR_CastedSelfsToType<type.cppType, preds>]>,
3232 summary, type.cppType>;
3333
34+ // Generates a type summary.
35+ // - For a single type: returns its summary.
36+ // - For multiple types: returns `any of <comma-separated summaries>`.
37+ class CIR_TypeSummaries<list<Type> types> {
38+ assert !not(!empty(types)), "expects non-empty list of types";
39+
40+ list<string> summaries = !foreach(type, types, type.summary);
41+ string joined = !interleave(summaries, ", ");
42+
43+ string value = !if(!eq(!size(types), 1), joined, "any of " # joined);
44+ }
45+
3446//===----------------------------------------------------------------------===//
3547// Bool Type predicates
3648//===----------------------------------------------------------------------===//
@@ -184,6 +196,8 @@ def CIR_PtrToVoidPtrType
184196// Vector Type predicates
185197//===----------------------------------------------------------------------===//
186198
199+ def CIR_AnyVectorType : CIR_TypeBase<"::cir::VectorType", "vector type">;
200+
187201// Vector of integral type
188202def IntegerVector : Type<
189203 And<[
@@ -211,4 +225,27 @@ def CIR_AnyScalarType : AnyTypeOf<CIR_ScalarTypes, "cir scalar type"> {
211225 let cppFunctionName = "isScalarType";
212226}
213227
228+ //===----------------------------------------------------------------------===//
229+ // Element type constraint bases
230+ //===----------------------------------------------------------------------===//
231+
232+ class CIR_ElementTypePred<Pred pred> : SubstLeaves<"$_self",
233+ "::mlir::cast<::cir::VectorType>($_self).getElementType()", pred>;
234+
235+ class CIR_VectorTypeOf<list<Type> types, string summary = "">
236+ : CIR_ConfinedType<CIR_AnyVectorType,
237+ [Or<!foreach(type, types, CIR_ElementTypePred<type.predicate>)>],
238+ !if(!empty(summary),
239+ "vector of " # CIR_TypeSummaries<types>.value,
240+ summary)>;
241+
242+ // Vector of type constraints
243+ def CIR_VectorOfFloatType : CIR_VectorTypeOf<[CIR_AnyFloatType]>;
244+
245+ def CIR_AnyFloatOrVecOfFloatType
246+ : AnyTypeOf<[CIR_AnyFloatType, CIR_VectorOfFloatType],
247+ "floating point or vector of floating point type"> {
248+ let cppFunctionName = "isFPOrVectorOfFPType";
249+ }
250+
214251#endif // CLANG_CIR_DIALECT_IR_CIRTYPECONSTRAINTS_TD
0 commit comments