@@ -31,6 +31,18 @@ class CIR_ConfinedType<Type type, list<Pred> preds, string summary = "">
3131 : Type<And<[type.predicate, CIR_CastedSelfsToType<type.cppType, preds>]>,
3232 summary, type.cppType>;
3333
34+ // Generates a type summary.
35+ // - For a single type: returns its summary.
36+ // - For multiple types: returns `any of <comma-separated summaries>`.
37+ class CIR_TypeSummaries<list<Type> types> {
38+ assert !not(!empty(types)), "expects non-empty list of types";
39+
40+ list<string> summaries = !foreach(type, types, type.summary);
41+ string joined = !interleave(summaries, ", ");
42+
43+ string value = !if(!eq(!size(types), 1), joined, "any of " # joined);
44+ }
45+
3446//===----------------------------------------------------------------------===//
3547// Bool Type predicates
3648//===----------------------------------------------------------------------===//
@@ -184,6 +196,24 @@ def CIR_PtrToVoidPtrType
184196// Vector Type predicates
185197//===----------------------------------------------------------------------===//
186198
199+ def CIR_AnyVectorType : CIR_TypeBase<"::cir::VectorType", "vector type">;
200+
201+ def CIR_VectorElementType : AnyTypeOf<[CIR_AnyIntOrFloatType, CIR_AnyPtrType],
202+ "any cir integer, floating point or pointer type"
203+ > {
204+ let cppFunctionName = "isValidVectorTypeElementType";
205+ }
206+
207+ class CIR_ElementTypePred<Pred pred> : SubstLeaves<"$_self",
208+ "::mlir::cast<::cir::VectorType>($_self).getElementType()", pred>;
209+
210+ class CIR_VectorTypeOf<list<Type> types, string summary = "">
211+ : CIR_ConfinedType<CIR_AnyVectorType,
212+ [Or<!foreach(type, types, CIR_ElementTypePred<type.predicate>)>],
213+ !if(!empty(summary),
214+ "vector of " # CIR_TypeSummaries<types>.value,
215+ summary)>;
216+
187217// Vector of integral type
188218def IntegerVector : Type<
189219 And<[
@@ -196,8 +226,36 @@ def IntegerVector : Type<
196226 ]>, "!cir.vector of !cir.int"> {
197227}
198228
199- // Any Integer or Vector of Integer Constraints
200- def CIR_AnyIntOrVecOfInt: AnyTypeOf<[CIR_AnyIntType, IntegerVector]>;
229+ // Vector of type constraints
230+ def CIR_VectorOfIntType : CIR_VectorTypeOf<[CIR_AnyIntType]>;
231+ def CIR_VectorOfUIntType : CIR_VectorTypeOf<[CIR_AnyUIntType]>;
232+ def CIR_VectorOfSIntType : CIR_VectorTypeOf<[CIR_AnySIntType]>;
233+ def CIR_VectorOfFloatType : CIR_VectorTypeOf<[CIR_AnyFloatType]>;
234+
235+ // Vector or Scalar type constraints
236+ def CIR_AnyIntOrVecOfIntType
237+ : AnyTypeOf<[CIR_AnyIntType, CIR_VectorOfIntType],
238+ "integer or vector of integer type"> {
239+ let cppFunctionName = "isIntOrVectorOfIntType";
240+ }
241+
242+ def CIR_AnySIntOrVecOfSIntType
243+ : AnyTypeOf<[CIR_AnySIntType, CIR_VectorOfSIntType],
244+ "signed integer or vector of signed integer type"> {
245+ let cppFunctionName = "isSIntOrVectorOfSIntType";
246+ }
247+
248+ def CIR_AnyUIntOrVecOfUIntType
249+ : AnyTypeOf<[CIR_AnyUIntType, CIR_VectorOfUIntType],
250+ "unsigned integer or vector of unsigned integer type"> {
251+ let cppFunctionName = "isUIntOrVectorOfUIntType";
252+ }
253+
254+ def CIR_AnyFloatOrVecOfFloatType
255+ : AnyTypeOf<[CIR_AnyFloatType, CIR_VectorOfFloatType],
256+ "floating point or vector of floating point type"> {
257+ let cppFunctionName = "isFPOrVectorOfFPType";
258+ }
201259
202260//===----------------------------------------------------------------------===//
203261// Scalar Type predicates
0 commit comments