@@ -31,6 +31,18 @@ class CIR_ConfinedType<Type type, list<Pred> preds, string summary = "">
3131    : Type<And<[type.predicate, CIR_CastedSelfsToType<type.cppType, preds>]>,
3232         summary, type.cppType>;
3333
34+ // Generates a type summary.
35+ // - For a single type: returns its summary.
36+ // - For multiple types: returns `any of <comma-separated summaries>`.
37+ class CIR_TypeSummaries<list<Type> types> {
38+     assert !not(!empty(types)), "expects non-empty list of types";
39+ 
40+     list<string> summaries = !foreach(type, types, type.summary);
41+     string joined = !interleave(summaries, ", ");
42+ 
43+     string value = !if(!eq(!size(types), 1), joined, "any of " # joined);
44+ }
45+ 
3446//===----------------------------------------------------------------------===//
3547// Bool Type predicates
3648//===----------------------------------------------------------------------===//
@@ -184,6 +196,24 @@ def CIR_PtrToVoidPtrType
184196// Vector Type predicates
185197//===----------------------------------------------------------------------===//
186198
199+ def CIR_AnyVectorType : CIR_TypeBase<"::cir::VectorType", "vector type">;
200+ 
201+ def CIR_VectorElementType : AnyTypeOf<[CIR_AnyIntOrFloatType, CIR_AnyPtrType],
202+     "any cir integer, floating point or pointer type"
203+ > {
204+     let cppFunctionName = "isValidVectorTypeElementType";
205+ }
206+ 
207+ class CIR_ElementTypePred<Pred pred> : SubstLeaves<"$_self",
208+     "::mlir::cast<::cir::VectorType>($_self).getElementType()", pred>;
209+ 
210+ class CIR_VectorTypeOf<list<Type> types, string summary = "">
211+     : CIR_ConfinedType<CIR_AnyVectorType,
212+         [Or<!foreach(type, types, CIR_ElementTypePred<type.predicate>)>],
213+         !if(!empty(summary),
214+             "vector of " # CIR_TypeSummaries<types>.value,
215+             summary)>;
216+ 
187217// Vector of integral type
188218def IntegerVector : Type<
189219    And<[
@@ -196,8 +226,36 @@ def IntegerVector : Type<
196226    ]>, "!cir.vector of !cir.int"> {
197227}
198228
199- // Any Integer or Vector of Integer Constraints
200- def CIR_AnyIntOrVecOfInt: AnyTypeOf<[CIR_AnyIntType, IntegerVector]>;
229+ // Vector of type constraints
230+ def CIR_VectorOfIntType : CIR_VectorTypeOf<[CIR_AnyIntType]>;
231+ def CIR_VectorOfUIntType : CIR_VectorTypeOf<[CIR_AnyUIntType]>;
232+ def CIR_VectorOfSIntType : CIR_VectorTypeOf<[CIR_AnySIntType]>;
233+ def CIR_VectorOfFloatType : CIR_VectorTypeOf<[CIR_AnyFloatType]>;
234+ 
235+ // Vector or Scalar type constraints
236+ def CIR_AnyIntOrVecOfIntType
237+     : AnyTypeOf<[CIR_AnyIntType, CIR_VectorOfIntType],
238+         "integer or vector of integer type"> {
239+     let cppFunctionName = "isIntOrVectorOfIntType";
240+ }
241+ 
242+ def CIR_AnySIntOrVecOfSIntType
243+     : AnyTypeOf<[CIR_AnySIntType, CIR_VectorOfSIntType],
244+         "signed integer or vector of signed integer type"> {
245+     let cppFunctionName = "isSIntOrVectorOfSIntType";
246+ }
247+ 
248+ def CIR_AnyUIntOrVecOfUIntType
249+     : AnyTypeOf<[CIR_AnyUIntType, CIR_VectorOfUIntType],
250+         "unsigned integer or vector of unsigned integer type"> {
251+     let cppFunctionName = "isUIntOrVectorOfUIntType";
252+ }
253+ 
254+ def CIR_AnyFloatOrVecOfFloatType
255+     : AnyTypeOf<[CIR_AnyFloatType, CIR_VectorOfFloatType],
256+         "floating point or vector of floating point type"> {
257+     let cppFunctionName = "isFPOrVectorOfFPType";
258+ }
201259
202260//===----------------------------------------------------------------------===//
203261// Scalar Type predicates
0 commit comments