Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1464,9 +1464,13 @@ def ShiftOp : CIR_Op<"shift", [Pure]> {
```
}];

let results = (outs CIR_AnyIntOrVecOfInt:$result);
let arguments = (ins CIR_AnyIntOrVecOfInt:$value, CIR_AnyIntOrVecOfInt:$amount,
UnitAttr:$isShiftleft);
let arguments = (ins
CIR_AnyIntOrVecOfIntType:$value,
CIR_AnyIntOrVecOfIntType:$amount,
UnitAttr:$isShiftleft
);

let results = (outs CIR_AnyIntOrVecOfIntType:$result);

let assemblyFormat = [{
`(`
Expand Down Expand Up @@ -2050,7 +2054,7 @@ def VecCreateOp : CIR_Op<"vec.create", [Pure]> {
in the vector type.
}];

let arguments = (ins Variadic<CIR_AnyType>:$elements);
let arguments = (ins Variadic<CIR_VectorElementType>:$elements);
let results = (outs CIR_VectorType:$result);

let assemblyFormat = [{
Expand Down Expand Up @@ -2085,7 +2089,7 @@ def VecInsertOp : CIR_Op<"vec.insert", [Pure,

let arguments = (ins
CIR_VectorType:$vec,
AnyType:$value,
CIR_VectorElementType:$value,
CIR_AnyFundamentalIntType:$index
);

Expand Down Expand Up @@ -2118,7 +2122,7 @@ def VecExtractOp : CIR_Op<"vec.extract", [Pure,
}];

let arguments = (ins CIR_VectorType:$vec, CIR_AnyFundamentalIntType:$index);
let results = (outs CIR_AnyType:$result);
let results = (outs CIR_VectorElementType:$result);

let assemblyFormat = [{
$vec `[` $index `:` type($index) `]` attr-dict `:` qualified(type($vec))
Expand Down Expand Up @@ -2180,7 +2184,7 @@ def VecShuffleDynamicOp : CIR_Op<"vec.shuffle.dynamic",
```
}];

let arguments = (ins CIR_VectorType:$vec, IntegerVector:$indices);
let arguments = (ins CIR_VectorType:$vec, CIR_VectorOfIntType:$indices);
let results = (outs CIR_VectorType:$result);
let assemblyFormat = [{
$vec `:` qualified(type($vec)) `,` $indices `:` qualified(type($indices))
Expand Down
62 changes: 60 additions & 2 deletions clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,18 @@ class CIR_ConfinedType<Type type, list<Pred> preds, string summary = "">
: Type<And<[type.predicate, CIR_CastedSelfsToType<type.cppType, preds>]>,
summary, type.cppType>;

// Generates a type summary.
// - For a single type: returns its summary.
// - For multiple types: returns `any of <comma-separated summaries>`.
class CIR_TypeSummaries<list<Type> types> {
assert !not(!empty(types)), "expects non-empty list of types";

list<string> summaries = !foreach(type, types, type.summary);
string joined = !interleave(summaries, ", ");

string value = !if(!eq(!size(types), 1), joined, "any of " # joined);
}

//===----------------------------------------------------------------------===//
// Bool Type predicates
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -184,6 +196,24 @@ def CIR_PtrToVoidPtrType
// Vector Type predicates
//===----------------------------------------------------------------------===//

def CIR_AnyVectorType : CIR_TypeBase<"::cir::VectorType", "vector type">;

def CIR_VectorElementType : AnyTypeOf<[CIR_AnyIntOrFloatType, CIR_AnyPtrType],
"any cir integer, floating point or pointer type"
> {
let cppFunctionName = "isValidVectorTypeElementType";
}

class CIR_ElementTypePred<Pred pred> : SubstLeaves<"$_self",
"::mlir::cast<::cir::VectorType>($_self).getElementType()", pred>;

class CIR_VectorTypeOf<list<Type> types, string summary = "">
: CIR_ConfinedType<CIR_AnyVectorType,
[Or<!foreach(type, types, CIR_ElementTypePred<type.predicate>)>],
!if(!empty(summary),
"vector of " # CIR_TypeSummaries<types>.value,
summary)>;

// Vector of integral type
def IntegerVector : Type<
And<[
Expand All @@ -196,8 +226,36 @@ def IntegerVector : Type<
]>, "!cir.vector of !cir.int"> {
}

// Any Integer or Vector of Integer Constraints
def CIR_AnyIntOrVecOfInt: AnyTypeOf<[CIR_AnyIntType, IntegerVector]>;
// Vector of type constraints
def CIR_VectorOfIntType : CIR_VectorTypeOf<[CIR_AnyIntType]>;
def CIR_VectorOfUIntType : CIR_VectorTypeOf<[CIR_AnyUIntType]>;
def CIR_VectorOfSIntType : CIR_VectorTypeOf<[CIR_AnySIntType]>;
def CIR_VectorOfFloatType : CIR_VectorTypeOf<[CIR_AnyFloatType]>;

// Vector or Scalar type constraints
def CIR_AnyIntOrVecOfIntType
: AnyTypeOf<[CIR_AnyIntType, CIR_VectorOfIntType],
"integer or vector of integer type"> {
let cppFunctionName = "isIntOrVectorOfIntType";
}

def CIR_AnySIntOrVecOfSIntType
: AnyTypeOf<[CIR_AnySIntType, CIR_VectorOfSIntType],
"signed integer or vector of signed integer type"> {
let cppFunctionName = "isSIntOrVectorOfSIntType";
}

def CIR_AnyUIntOrVecOfUIntType
: AnyTypeOf<[CIR_AnyUIntType, CIR_VectorOfUIntType],
"unsigned integer or vector of unsigned integer type"> {
let cppFunctionName = "isUIntOrVectorOfUIntType";
}

def CIR_AnyFloatOrVecOfFloatType
: AnyTypeOf<[CIR_AnyFloatType, CIR_VectorOfFloatType],
"floating point or vector of floating point type"> {
let cppFunctionName = "isFPOrVectorOfFPType";
}

//===----------------------------------------------------------------------===//
// Scalar Type predicates
Expand Down
2 changes: 0 additions & 2 deletions clang/include/clang/CIR/Dialect/IR/CIRTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@ struct RecordTypeStorage;

bool isValidFundamentalIntWidth(unsigned width);

bool isFPOrFPVectorTy(mlir::Type);

} // namespace cir

//===----------------------------------------------------------------------===//
Expand Down
25 changes: 19 additions & 6 deletions clang/include/clang/CIR/Dialect/IR/CIRTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -275,18 +275,31 @@ def CIR_VectorType : CIR_Type<"Vector", "vector",

let summary = "CIR vector type";
let description = [{
`!cir.vector' represents fixed-size vector types, parameterized
by the element type and the number of elements.
The `!cir.vector` type represents a fixed-size, one-dimensional vector.
It takes two parameters: the element type and the number of elements.

Example:
Syntax:

```mlir
!cir.vector<!u64i x 2>
!cir.vector<!cir.float x 4>
vector-type ::= !cir.vector<size x element-type>
element-type ::= float-type | integer-type | pointer-type
```

The `element-type` must be a scalar CIR type. Zero-sized vectors are not
allowed. The `size` must be a positive integer.

Examples:

```mlir
!cir.vector<4 x !cir.int<u, 8>>
!cir.vector<2 x !cir.float>
```
}];

let parameters = (ins "mlir::Type":$elementType, "uint64_t":$size);
let parameters = (ins
CIR_VectorElementType:$elementType,
"uint64_t":$size
);

let assemblyFormat = [{
`<` $size `x` $elementType `>`
Expand Down
6 changes: 3 additions & 3 deletions clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1311,7 +1311,7 @@ mlir::Value ScalarExprEmitter::emitMul(const BinOpInfo &ops) {
!canElideOverflowCheck(cgf.getContext(), ops))
cgf.cgm.errorNYI("unsigned int overflow sanitizer");

if (cir::isFPOrFPVectorTy(ops.lhs.getType())) {
if (cir::isFPOrVectorOfFPType(ops.lhs.getType())) {
assert(!cir::MissingFeatures::cgFPOptionsRAII());
return builder.createFMul(loc, ops.lhs, ops.rhs);
}
Expand Down Expand Up @@ -1370,7 +1370,7 @@ mlir::Value ScalarExprEmitter::emitAdd(const BinOpInfo &ops) {
!canElideOverflowCheck(cgf.getContext(), ops))
cgf.cgm.errorNYI("unsigned int overflow sanitizer");

if (cir::isFPOrFPVectorTy(ops.lhs.getType())) {
if (cir::isFPOrVectorOfFPType(ops.lhs.getType())) {
assert(!cir::MissingFeatures::cgFPOptionsRAII());
return builder.createFAdd(loc, ops.lhs, ops.rhs);
}
Expand Down Expand Up @@ -1418,7 +1418,7 @@ mlir::Value ScalarExprEmitter::emitSub(const BinOpInfo &ops) {
!canElideOverflowCheck(cgf.getContext(), ops))
cgf.cgm.errorNYI("unsigned int overflow sanitizer");

if (cir::isFPOrFPVectorTy(ops.lhs.getType())) {
if (cir::isFPOrVectorOfFPType(ops.lhs.getType())) {
assert(!cir::MissingFeatures::cgFPOptionsRAII());
return builder.createFSub(loc, ops.lhs, ops.rhs);
}
Expand Down
21 changes: 1 addition & 20 deletions clang/lib/CIR/Dialect/IR/CIRTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -552,15 +552,6 @@ LongDoubleType::getABIAlignment(const mlir::DataLayout &dataLayout,
.getABIAlignment(dataLayout, params);
}

//===----------------------------------------------------------------------===//
// Floating-point and Float-point Vector type helpers
//===----------------------------------------------------------------------===//

bool cir::isFPOrFPVectorTy(mlir::Type t) {
assert(!cir::MissingFeatures::vectorType());
return isAnyFloatingPointType(t);
}

//===----------------------------------------------------------------------===//
// FuncType Definitions
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -693,17 +684,7 @@ mlir::LogicalResult cir::VectorType::verify(
mlir::Type elementType, uint64_t size) {
if (size == 0)
return emitError() << "the number of vector elements must be non-zero";

// Check if it a valid FixedVectorType
if (mlir::isa<cir::PointerType, cir::FP128Type>(elementType))
return success();

// Check if it a valid VectorType
if (mlir::isa<cir::IntType>(elementType) ||
isAnyFloatingPointType(elementType))
return success();

return emitError() << "unsupported element type for CIR vector";
return success();
}

//===----------------------------------------------------------------------===//
Expand Down
2 changes: 1 addition & 1 deletion clang/test/CIR/IR/invalid-vector.cir
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

module {

// expected-error @below {{unsupported element type for CIR vector}}
// expected-error @below {{failed to verify 'elementType'}}
cir.global external @vec_b = #cir.zero : !cir.vector<4 x !cir.array<!s32i x 10>>

}