Skip to content

Commit 9fd2f92

Browse files
committed
[CIR] Upstream unary operators for VectorType
1 parent 369891b commit 9fd2f92

File tree

3 files changed

+149
-17
lines changed

3 files changed

+149
-17
lines changed

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 37 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,10 @@ namespace direct {
5454
namespace {
5555
/// If the given type is a vector type, return the vector's element type.
5656
/// Otherwise return the given type unchanged.
57-
// TODO(cir): Return the vector element type once we have support for vectors
58-
// instead of the identity type.
5957
mlir::Type elementTypeIfVector(mlir::Type type) {
60-
assert(!cir::MissingFeatures::vectorType());
58+
if (const auto vecType = mlir::dyn_cast<cir::VectorType>(type)) {
59+
return vecType.getElementType();
60+
}
6161
return type;
6262
}
6363
} // namespace
@@ -1043,12 +1043,11 @@ mlir::LogicalResult CIRToLLVMUnaryOpLowering::matchAndRewrite(
10431043
mlir::ConversionPatternRewriter &rewriter) const {
10441044
assert(op.getType() == op.getInput().getType() &&
10451045
"Unary operation's operand type and result type are different");
1046-
mlir::Type type = op.getType();
1047-
mlir::Type elementType = type;
1048-
bool isVector = false;
1049-
assert(!cir::MissingFeatures::vectorType());
1050-
mlir::Type llvmType = getTypeConverter()->convertType(type);
1051-
mlir::Location loc = op.getLoc();
1046+
const mlir::Type type = op.getType();
1047+
const mlir::Type elementType = elementTypeIfVector(type);
1048+
const bool isVector = mlir::isa<cir::VectorType>(type);
1049+
const mlir::Type llvmType = getTypeConverter()->convertType(type);
1050+
const mlir::Location loc = op.getLoc();
10521051

10531052
// Integer unary operations: + - ~ ++ --
10541053
if (mlir::isa<cir::IntType>(elementType)) {
@@ -1076,20 +1075,41 @@ mlir::LogicalResult CIRToLLVMUnaryOpLowering::matchAndRewrite(
10761075
rewriter.replaceOp(op, adaptor.getInput());
10771076
return mlir::success();
10781077
case cir::UnaryOpKind::Minus: {
1079-
assert(!isVector &&
1080-
"Add vector handling when vector types are supported");
1081-
mlir::LLVM::ConstantOp zero = rewriter.create<mlir::LLVM::ConstantOp>(
1082-
loc, llvmType, mlir::IntegerAttr::get(llvmType, 0));
1078+
mlir::Value zero;
1079+
if (isVector)
1080+
zero = rewriter.create<mlir::LLVM::ZeroOp>(loc, llvmType);
1081+
else
1082+
zero = rewriter.create<mlir::LLVM::ConstantOp>(
1083+
loc, llvmType, mlir::IntegerAttr::get(llvmType, 0));
10831084
rewriter.replaceOpWithNewOp<mlir::LLVM::SubOp>(
10841085
op, llvmType, zero, adaptor.getInput(), maybeNSW);
10851086
return mlir::success();
10861087
}
10871088
case cir::UnaryOpKind::Not: {
10881089
// Bitwise complement operator, implemented as an XOR with -1.
1089-
assert(!isVector &&
1090-
"Add vector handling when vector types are supported");
1091-
mlir::LLVM::ConstantOp minusOne = rewriter.create<mlir::LLVM::ConstantOp>(
1092-
loc, llvmType, mlir::IntegerAttr::get(llvmType, -1));
1090+
mlir::Value minusOne;
1091+
if (isVector) {
1092+
// Creating a vector object with all -1 values is easier said than
1093+
// done. It requires a series of insertelement ops.
1094+
const mlir::Type llvmElementType =
1095+
getTypeConverter()->convertType(elementType);
1096+
const mlir::Value minusOneInt = rewriter.create<mlir::LLVM::ConstantOp>(
1097+
loc, llvmElementType, mlir::IntegerAttr::get(llvmElementType, -1));
1098+
minusOne = rewriter.create<mlir::LLVM::UndefOp>(loc, llvmType);
1099+
1100+
const uint64_t numElements =
1101+
mlir::dyn_cast<cir::VectorType>(type).getSize();
1102+
for (uint64_t i = 0; i < numElements; ++i) {
1103+
const mlir::Value indexValue =
1104+
rewriter.create<mlir::LLVM::ConstantOp>(loc,
1105+
rewriter.getI64Type(), i);
1106+
minusOne = rewriter.create<mlir::LLVM::InsertElementOp>(
1107+
loc, minusOne, minusOneInt, indexValue);
1108+
}
1109+
} else {
1110+
minusOne = rewriter.create<mlir::LLVM::ConstantOp>(
1111+
loc, llvmType, mlir::IntegerAttr::get(llvmType, -1));
1112+
}
10931113
rewriter.replaceOpWithNewOp<mlir::LLVM::XOrOp>(
10941114
op, llvmType, adaptor.getInput(), minusOne);
10951115
return mlir::success();

clang/test/CIR/CodeGen/vector-ext.cpp

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,3 +213,59 @@ void foo4() {
213213
// OGCG: %[[TMP2:.*]] = load i32, ptr %[[IDX]], align 4
214214
// OGCG: %[[ELE:.*]] = extractelement <4 x i32> %[[TMP1]], i32 %[[TMP2]]
215215
// OGCG: store i32 %[[ELE]], ptr %[[INIT]], align 4
216+
217+
void foo8() {
218+
vi4 a = { 1, 2, 3, 4 };
219+
vi4 plus_res = +a;
220+
vi4 minus_res = -a;
221+
vi4 not_res = ~a;
222+
}
223+
224+
// CIR: %[[VEC:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a", init]
225+
// CIR: %[[PLUS_RES:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["plus_res", init]
226+
// CIR: %[[MINUS_RES:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["minus_res", init]
227+
// CIR: %[[NOT_RES:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["not_res", init]
228+
// CIR: %[[CONST_1:.*]] = cir.const #cir.int<1> : !s32i
229+
// CIR: %[[CONST_2:.*]] = cir.const #cir.int<2> : !s32i
230+
// CIR: %[[CONST_3:.*]] = cir.const #cir.int<3> : !s32i
231+
// CIR: %[[CONST_4:.*]] = cir.const #cir.int<4> : !s32i
232+
// CIR: %[[VEC_VAL:.*]] = cir.vec.create(%[[CONST_1]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]] :
233+
// CIR-SAME: !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
234+
// CIR: cir.store %[[VEC_VAL]], %[[VEC]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
235+
// CIR: %[[TMP1:.*]] = cir.load %[[VEC]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
236+
// CIR: %[[PLUS:.*]] = cir.unary(plus, %[[TMP1]]) : !cir.vector<4 x !s32i>, !cir.vector<4 x !s32i>
237+
// CIR: cir.store %[[PLUS]], %[[PLUS_RES]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
238+
// CIR: %[[TMP2:.*]] = cir.load %[[VEC]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
239+
// CIR: %[[MINUS:.*]] = cir.unary(minus, %[[TMP2]]) : !cir.vector<4 x !s32i>, !cir.vector<4 x !s32i>
240+
// CIR: cir.store %[[MINUS]], %[[MINUS_RES]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
241+
// CIR: %[[TMP3:.*]] = cir.load %[[VEC]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
242+
// CIR: %[[NOT:.*]] = cir.unary(not, %[[TMP3]]) : !cir.vector<4 x !s32i>, !cir.vector<4 x !s32i>
243+
// CIR: cir.store %[[NOT]], %[[NOT_RES]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
244+
245+
// LLVM: %[[VEC:.*]] = alloca <4 x i32>, i64 1, align 16
246+
// LLVM: %[[PLUS_RES:.*]] = alloca <4 x i32>, i64 1, align 16
247+
// LLVM: %[[MINUS_RES:.*]] = alloca <4 x i32>, i64 1, align 16
248+
// LLVM: %[[NOT_RES:.*]] = alloca <4 x i32>, i64 1, align 16
249+
// LLVM: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC]], align 16
250+
// LLVM: %[[TMP1:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
251+
// LLVM: store <4 x i32> %[[TMP1]], ptr %[[PLUS_RES]], align 16
252+
// LLVM: %[[TMP2:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
253+
// LLVM: %[[SUB:.*]] = sub <4 x i32> zeroinitializer, %[[TMP2]]
254+
// LLVM: store <4 x i32> %[[SUB]], ptr %[[MINUS_RES]], align 16
255+
// LLVM: %[[TMP3:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
256+
// LLVM: %[[NOT:.*]] = xor <4 x i32> %[[TMP3]], splat (i32 -1)
257+
// LLVM: store <4 x i32> %[[NOT]], ptr %[[NOT_RES]], align 16
258+
259+
// OGCG: %[[VEC:.*]] = alloca <4 x i32>, align 16
260+
// OGCG: %[[PLUS_RES:.*]] = alloca <4 x i32>, align 16
261+
// OGCG: %[[MINUS_RES:.*]] = alloca <4 x i32>, align 16
262+
// OGCG: %[[NOT_RES:.*]] = alloca <4 x i32>, align 16
263+
// OGCG: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC]], align 16
264+
// OGCG: %[[TMP1:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
265+
// OGCG: store <4 x i32> %[[TMP1]], ptr %[[PLUS_RES]], align 16
266+
// OGCG: %[[TMP2:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
267+
// OGCG: %[[SUB:.*]] = sub <4 x i32> zeroinitializer, %[[TMP2]]
268+
// OGCG: store <4 x i32> %[[SUB]], ptr %[[MINUS_RES]], align 16
269+
// OGCG: %[[TMP3:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
270+
// OGCG: %[[NOT:.*]] = xor <4 x i32> %[[TMP3]], splat (i32 -1)
271+
// OGCG: store <4 x i32> %[[NOT]], ptr %[[NOT_RES]], align 16

clang/test/CIR/CodeGen/vector.cpp

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,3 +201,59 @@ void foo4() {
201201
// OGCG: %[[TMP2:.*]] = load i32, ptr %[[IDX]], align 4
202202
// OGCG: %[[ELE:.*]] = extractelement <4 x i32> %[[TMP1]], i32 %[[TMP2]]
203203
// OGCG: store i32 %[[ELE]], ptr %[[INIT]], align 4
204+
205+
void foo8() {
206+
vi4 a = { 1, 2, 3, 4 };
207+
vi4 plus_res = +a;
208+
vi4 minus_res = -a;
209+
vi4 not_res = ~a;
210+
}
211+
212+
// CIR: %[[VEC:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a", init]
213+
// CIR: %[[PLUS_RES:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["plus_res", init]
214+
// CIR: %[[MINUS_RES:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["minus_res", init]
215+
// CIR: %[[NOT_RES:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["not_res", init]
216+
// CIR: %[[CONST_1:.*]] = cir.const #cir.int<1> : !s32i
217+
// CIR: %[[CONST_2:.*]] = cir.const #cir.int<2> : !s32i
218+
// CIR: %[[CONST_3:.*]] = cir.const #cir.int<3> : !s32i
219+
// CIR: %[[CONST_4:.*]] = cir.const #cir.int<4> : !s32i
220+
// CIR: %[[VEC_VAL:.*]] = cir.vec.create(%[[CONST_1]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]] :
221+
// CIR-SAME: !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
222+
// CIR: cir.store %[[VEC_VAL]], %[[VEC]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
223+
// CIR: %[[TMP1:.*]] = cir.load %[[VEC]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
224+
// CIR: %[[PLUS:.*]] = cir.unary(plus, %[[TMP1]]) : !cir.vector<4 x !s32i>, !cir.vector<4 x !s32i>
225+
// CIR: cir.store %[[PLUS]], %[[PLUS_RES]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
226+
// CIR: %[[TMP2:.*]] = cir.load %[[VEC]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
227+
// CIR: %[[MINUS:.*]] = cir.unary(minus, %[[TMP2]]) : !cir.vector<4 x !s32i>, !cir.vector<4 x !s32i>
228+
// CIR: cir.store %[[MINUS]], %[[MINUS_RES]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
229+
// CIR: %[[TMP3:.*]] = cir.load %[[VEC]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
230+
// CIR: %[[NOT:.*]] = cir.unary(not, %[[TMP3]]) : !cir.vector<4 x !s32i>, !cir.vector<4 x !s32i>
231+
// CIR: cir.store %[[NOT]], %[[NOT_RES]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
232+
233+
// LLVM: %[[VEC:.*]] = alloca <4 x i32>, i64 1, align 16
234+
// LLVM: %[[PLUS_RES:.*]] = alloca <4 x i32>, i64 1, align 16
235+
// LLVM: %[[MINUS_RES:.*]] = alloca <4 x i32>, i64 1, align 16
236+
// LLVM: %[[NOT_RES:.*]] = alloca <4 x i32>, i64 1, align 16
237+
// LLVM: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC]], align 16
238+
// LLVM: %[[TMP1:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
239+
// LLVM: store <4 x i32> %[[TMP1]], ptr %[[PLUS_RES]], align 16
240+
// LLVM: %[[TMP2:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
241+
// LLVM: %[[SUB:.*]] = sub <4 x i32> zeroinitializer, %[[TMP2]]
242+
// LLVM: store <4 x i32> %[[SUB]], ptr %[[MINUS_RES]], align 16
243+
// LLVM: %[[TMP3:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
244+
// LLVM: %[[NOT:.*]] = xor <4 x i32> %[[TMP3]], splat (i32 -1)
245+
// LLVM: store <4 x i32> %[[NOT]], ptr %[[NOT_RES]], align 16
246+
247+
// OGCG: %[[VEC:.*]] = alloca <4 x i32>, align 16
248+
// OGCG: %[[PLUS_RES:.*]] = alloca <4 x i32>, align 16
249+
// OGCG: %[[MINUS_RES:.*]] = alloca <4 x i32>, align 16
250+
// OGCG: %[[NOT_RES:.*]] = alloca <4 x i32>, align 16
251+
// OGCG: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC]], align 16
252+
// OGCG: %[[TMP1:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
253+
// OGCG: store <4 x i32> %[[TMP1]], ptr %[[PLUS_RES]], align 16
254+
// OGCG: %[[TMP2:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
255+
// OGCG: %[[SUB:.*]] = sub <4 x i32> zeroinitializer, %[[TMP2]]
256+
// OGCG: store <4 x i32> %[[SUB]], ptr %[[MINUS_RES]], align 16
257+
// OGCG: %[[TMP3:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
258+
// OGCG: %[[NOT:.*]] = xor <4 x i32> %[[TMP3]], splat (i32 -1)
259+
// OGCG: store <4 x i32> %[[NOT]], ptr %[[NOT_RES]], align 16

0 commit comments

Comments
 (0)