Commit 3f4e35e

[CIR] Support type promotion for Scalar unary plus & minus ops
1 parent 1b1b83f commit 3f4e35e
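
In short: for operand types that have a promotion type (such as _Float16, which the configurations exercised by the new tests promote to float), unary + and - are now evaluated in the promoted type and truncated back to the original type only at the end, including nested uses routed through emitPromoted. A minimal source-level illustration, mirroring the f16NestedUMinus test added below; the comment lists the operations the FileCheck lines expect:

void f16NestedUMinus() {
  _Float16 a;
  _Float16 b = -(-a); // fpext a to float, fneg, fneg, fptrunc back to half
}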

File tree (2 files changed: +98, -13 lines)

  clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
  clang/test/CIR/CodeGen/unary.cpp


clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp

Lines changed: 32 additions & 13 deletions
@@ -621,19 +621,27 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
   }
 
   mlir::Value VisitUnaryPlus(const UnaryOperator *e) {
-    return emitUnaryPlusOrMinus(e, cir::UnaryOpKind::Plus);
+    QualType promotionType = getPromotionType(e->getSubExpr()->getType());
+    mlir::Value result =
+        emitUnaryPlusOrMinus(e, cir::UnaryOpKind::Plus, promotionType);
+    if (result && !promotionType.isNull())
+      return emitUnPromotedValue(result, e->getType());
+    return result;
   }
 
   mlir::Value VisitUnaryMinus(const UnaryOperator *e) {
-    return emitUnaryPlusOrMinus(e, cir::UnaryOpKind::Minus);
+    QualType promotionType = getPromotionType(e->getSubExpr()->getType());
+    mlir::Value result =
+        emitUnaryPlusOrMinus(e, cir::UnaryOpKind::Minus, promotionType);
+    if (result && !promotionType.isNull())
+      return emitUnPromotedValue(result, e->getType());
+    return result;
   }
 
   mlir::Value emitUnaryPlusOrMinus(const UnaryOperator *e,
-                                   cir::UnaryOpKind kind) {
+                                   cir::UnaryOpKind kind,
+                                   QualType promotionType) {
     ignoreResultAssign = false;
-
-    QualType promotionType = getPromotionType(e->getSubExpr()->getType());
-
     mlir::Value operand;
     if (!promotionType.isNull())
       operand = cgf.emitPromotedScalarExpr(e->getSubExpr(), promotionType);
@@ -645,10 +653,7 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
 
     // NOTE: LLVM codegen will lower this directly to either a FNeg
     // or a Sub instruction. In CIR this will be handled later in LowerToLLVM.
-    mlir::Value result = emitUnaryOp(e, kind, operand, nsw);
-    if (result && !promotionType.isNull())
-      return emitUnPromotedValue(result, e->getType());
-    return result;
+    return emitUnaryOp(e, kind, operand, nsw);
   }
 
   mlir::Value emitUnaryOp(const UnaryOperator *e, cir::UnaryOpKind kind,
@@ -1239,9 +1244,23 @@ mlir::Value ScalarExprEmitter::emitPromoted(const Expr *e,
     default:
       break;
     }
-  } else if (isa<UnaryOperator>(e)) {
-    cgf.cgm.errorNYI(e->getSourceRange(), "unary operators");
-    return {};
+  } else if (const auto *uo = dyn_cast<UnaryOperator>(e)) {
+    switch (uo->getOpcode()) {
+    case UO_Imag:
+      cgf.cgm.errorNYI(e->getSourceRange(),
+                       "ScalarExprEmitter::emitPromoted unary imag");
+      return {};
+    case UO_Real:
+      cgf.cgm.errorNYI(e->getSourceRange(),
+                       "ScalarExprEmitter::emitPromoted unary real");
+      return {};
+    case UO_Minus:
+      return emitUnaryPlusOrMinus(uo, cir::UnaryOpKind::Minus, promotionType);
+    case UO_Plus:
+      return emitUnaryPlusOrMinus(uo, cir::UnaryOpKind::Plus, promotionType);
+    default:
+      break;
+    }
   }
   mlir::Value result = Visit(const_cast<Expr *>(e));
   if (result) {
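
For reference, the helper's contract after this change: emitUnaryPlusOrMinus takes the promotion type explicitly and returns a value in that type (or in the expression's own type when the promotion type is null); truncating back is now the caller's responsibility. A hedged sketch of a caller, using the hypothetical name visitPromotedUnary (illustrative only, not part of the patch):

// Hypothetical caller showing the new contract. emitPromoted deliberately
// skips the un-promotion step so an enclosing promoted operation can consume
// the float value directly; the Visit* entry points above do it themselves.
mlir::Value visitPromotedUnary(const UnaryOperator *e) {
  QualType promotionType = getPromotionType(e->getSubExpr()->getType());
  mlir::Value result =
      emitUnaryPlusOrMinus(e, cir::UnaryOpKind::Minus, promotionType);
  if (result && !promotionType.isNull())
    result = emitUnPromotedValue(result, e->getType()); // e.g. back to f16
  return result;
}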

clang/test/CIR/CodeGen/unary.cpp

Lines changed: 66 additions & 0 deletions
@@ -556,3 +556,69 @@ void test_logical_not() {
 // OGCG: %[[D_NOT:.*]] = xor i1 %[[D_BOOL]], true
 // OGCG: %[[D_CAST:.*]] = zext i1 %[[D_NOT]] to i8
 // OGCG: store i8 %[[D_CAST]], ptr %[[B_ADDR]], align 1
+
+void f16NestedUPlus() {
+  _Float16 a;
+  _Float16 b = +(+a);
+}
+
+// CHECK: cir.func{{.*}} @_Z14f16NestedUPlusv()
+// CHECK: %[[A_ADDR:.*]] = cir.alloca !cir.f16, !cir.ptr<!cir.f16>, ["a"]
+// CHECK: %[[B_ADDR:.*]] = cir.alloca !cir.f16, !cir.ptr<!cir.f16>, ["b", init]
+// CHECK: %[[TMP_A:.*]] = cir.load{{.*}} %[[A_ADDR]] : !cir.ptr<!cir.f16>, !cir.f16
+// CHECK: %[[A_F32:.*]] = cir.cast(floating, %[[TMP_A]] : !cir.f16), !cir.float
+// CHECK: %[[A_PLUS:.*]] = cir.unary(plus, %[[A_F32]]) : !cir.float, !cir.float
+// CHECK: %[[RESULT_F32:.*]] = cir.unary(plus, %[[A_PLUS]]) : !cir.float, !cir.float
+// CHECK: %[[RESULT:.*]] = cir.cast(floating, %[[RESULT_F32]] : !cir.float), !cir.f16
+// CHECK: cir.store{{.*}} %[[RESULT]], %[[B_ADDR]] : !cir.f16, !cir.ptr<!cir.f16>
+
+// LLVM: define{{.*}} void @_Z14f16NestedUPlusv()
+// LLVM: %[[A_ADDR:.*]] = alloca half, i64 1, align 2
+// LLVM: %[[B_ADDR:.*]] = alloca half, i64 1, align 2
+// LLVM: %[[TMP_A:.*]] = load half, ptr %[[A_ADDR]], align 2
+// LLVM: %[[RESULT_F32:.*]] = fpext half %[[TMP_A]] to float
+// LLVM: %[[RESULT:.*]] = fptrunc float %[[RESULT_F32]] to half
+// LLVM: store half %[[RESULT]], ptr %[[B_ADDR]], align 2
+
+// OGCG: define{{.*}} void @_Z14f16NestedUPlusv()
+// OGCG: %[[A_ADDR:.*]] = alloca half, align 2
+// OGCG: %[[B_ADDR:.*]] = alloca half, align 2
+// OGCG: %[[TMP_A:.*]] = load half, ptr %[[A_ADDR]], align 2
+// OGCG: %[[RESULT_F32:.*]] = fpext half %[[TMP_A]] to float
+// OGCG: %[[RESULT:.*]] = fptrunc float %[[RESULT_F32]] to half
+// OGCG: store half %[[RESULT]], ptr %[[B_ADDR]], align 2
+
+void f16NestedUMinus() {
+  _Float16 a;
+  _Float16 b = -(-a);
+}
+
+// CHECK: cir.func{{.*}} @_Z15f16NestedUMinusv()
+// CHECK: %[[A_ADDR:.*]] = cir.alloca !cir.f16, !cir.ptr<!cir.f16>, ["a"]
+// CHECK: %[[B_ADDR:.*]] = cir.alloca !cir.f16, !cir.ptr<!cir.f16>, ["b", init]
+// CHECK: %[[TMP_A:.*]] = cir.load{{.*}} %[[A_ADDR]] : !cir.ptr<!cir.f16>, !cir.f16
+// CHECK: %[[A_F32:.*]] = cir.cast(floating, %[[TMP_A]] : !cir.f16), !cir.float
+// CHECK: %[[A_MINUS:.*]] = cir.unary(minus, %[[A_F32]]) : !cir.float, !cir.float
+// CHECK: %[[RESULT_F32:.*]] = cir.unary(minus, %[[A_MINUS]]) : !cir.float, !cir.float
+// CHECK: %[[RESULT:.*]] = cir.cast(floating, %[[RESULT_F32]] : !cir.float), !cir.f16
+// CHECK: cir.store{{.*}} %[[RESULT]], %[[B_ADDR]] : !cir.f16, !cir.ptr<!cir.f16>
+
+// LLVM: define{{.*}} void @_Z15f16NestedUMinusv()
+// LLVM: %[[A_ADDR:.*]] = alloca half, i64 1, align 2
+// LLVM: %[[B_ADDR:.*]] = alloca half, i64 1, align 2
+// LLVM: %[[TMP_A:.*]] = load half, ptr %[[A_ADDR]], align 2
+// LLVM: %[[A_F32:.*]] = fpext half %[[TMP_A]] to float
+// LLVM: %[[A_MINUS:.*]] = fneg float %[[A_F32]]
+// LLVM: %[[RESULT_F32:.*]] = fneg float %[[A_MINUS]]
+// LLVM: %[[RESULT:.*]] = fptrunc float %[[RESULT_F32]] to half
+// LLVM: store half %[[RESULT]], ptr %[[B_ADDR]], align 2
+
+// OGCG: define{{.*}} void @_Z15f16NestedUMinusv()
+// OGCG: %[[A_ADDR:.*]] = alloca half, align 2
+// OGCG: %[[B_ADDR:.*]] = alloca half, align 2
+// OGCG: %[[TMP_A:.*]] = load half, ptr %[[A_ADDR]], align 2
+// OGCG: %[[A_F32:.*]] = fpext half %[[TMP_A]] to float
+// OGCG: %[[A_MINUS:.*]] = fneg float %[[A_F32]]
+// OGCG: %[[RESULT_F32:.*]] = fneg float %[[A_MINUS]]
+// OGCG: %[[RESULT:.*]] = fptrunc float %[[RESULT_F32]] to half
+// OGCG: store half %[[RESULT]], ptr %[[B_ADDR]], align 2
