Skip to content

Commit 33c950f

Browse files
committed
[CIR] Support type promotion for Scalar unary real & imag ops
1 parent 1b1b83f commit 33c950f

File tree

3 files changed

+115
-28
lines changed

3 files changed

+115
-28
lines changed

clang/include/clang/CIR/MissingFeatures.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,6 @@ struct MissingFeatures {
289289
static bool scalableVectors() { return false; }
290290
static bool unsizedTypes() { return false; }
291291
static bool vectorType() { return false; }
292-
static bool complexType() { return false; }
293292
static bool fixedPointType() { return false; }
294293
static bool stringTypeWithDifferentArraySize() { return false; }
295294

clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp

Lines changed: 41 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -667,8 +667,9 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
667667
mlir::Value VisitUnaryLNot(const UnaryOperator *e);
668668

669669
mlir::Value VisitUnaryReal(const UnaryOperator *e);
670-
671670
mlir::Value VisitUnaryImag(const UnaryOperator *e);
671+
mlir::Value VisitRealImag(const UnaryOperator *e,
672+
QualType promotionType = QualType());
672673

673674
mlir::Value VisitCXXDefaultInitExpr(CXXDefaultInitExpr *die) {
674675
CIRGenFunction::CXXDefaultInitExprScope scope(cgf, die);
@@ -864,11 +865,13 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
864865
// TODO(cir): Candidate to be in a common AST helper between CIR and LLVM
865866
// codegen.
866867
QualType getPromotionType(QualType ty) {
867-
if (ty->getAs<ComplexType>()) {
868-
assert(!cir::MissingFeatures::complexType());
869-
cgf.cgm.errorNYI("promotion to complex type");
870-
return QualType();
868+
const clang::ASTContext &ctx = cgf.getContext();
869+
if (auto *complexTy = ty->getAs<ComplexType>()) {
870+
QualType elementTy = complexTy->getElementType();
871+
if (elementTy.UseExcessPrecision(ctx))
872+
return ctx.getComplexType(ctx.FloatTy);
871873
}
874+
872875
if (ty.UseExcessPrecision(cgf.getContext())) {
873876
if (ty->getAs<VectorType>()) {
874877
assert(!cir::MissingFeatures::vectorType());
@@ -877,6 +880,7 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
877880
}
878881
return cgf.getContext().FloatTy;
879882
}
883+
880884
return QualType();
881885
}
882886

@@ -2057,28 +2061,27 @@ mlir::Value ScalarExprEmitter::VisitUnaryLNot(const UnaryOperator *e) {
20572061
}
20582062

20592063
mlir::Value ScalarExprEmitter::VisitUnaryReal(const UnaryOperator *e) {
2060-
// TODO(cir): handle scalar promotion.
2061-
Expr *op = e->getSubExpr();
2062-
if (op->getType()->isAnyComplexType()) {
2063-
// If it's an l-value, load through the appropriate subobject l-value.
2064-
// Note that we have to ask `e` because `op` might be an l-value that
2065-
// this won't work for, e.g. an Obj-C property.
2066-
if (e->isGLValue()) {
2067-
mlir::Location loc = cgf.getLoc(e->getExprLoc());
2068-
mlir::Value complex = cgf.emitComplexExpr(op);
2069-
return cgf.builder.createComplexReal(loc, complex);
2070-
}
2071-
2072-
// Otherwise, calculate and project.
2073-
cgf.cgm.errorNYI(e->getSourceRange(),
2074-
"VisitUnaryReal calculate and project");
2075-
}
2076-
2077-
return Visit(op);
2064+
QualType promotionTy = getPromotionType(e->getSubExpr()->getType());
2065+
mlir::Value result = VisitRealImag(e, promotionTy);
2066+
if (result && !promotionTy.isNull())
2067+
result = emitUnPromotedValue(result, e->getType());
2068+
return result;
20782069
}
20792070

20802071
mlir::Value ScalarExprEmitter::VisitUnaryImag(const UnaryOperator *e) {
2081-
// TODO(cir): handle scalar promotion.
2072+
QualType promotionTy = getPromotionType(e->getSubExpr()->getType());
2073+
mlir::Value result = VisitRealImag(e, promotionTy);
2074+
if (result && !promotionTy.isNull())
2075+
result = emitUnPromotedValue(result, e->getType());
2076+
return result;
2077+
}
2078+
2079+
mlir::Value ScalarExprEmitter::VisitRealImag(const UnaryOperator *e,
2080+
QualType promotionTy) {
2081+
assert(e->getOpcode() == clang::UO_Real ||
2082+
e->getOpcode() == clang::UO_Imag &&
2083+
"Invalid UnaryOp kind for ComplexType Real or Imag");
2084+
20822085
Expr *op = e->getSubExpr();
20832086
if (op->getType()->isAnyComplexType()) {
20842087
// If it's an l-value, load through the appropriate subobject l-value.
@@ -2087,15 +2090,26 @@ mlir::Value ScalarExprEmitter::VisitUnaryImag(const UnaryOperator *e) {
20872090
if (e->isGLValue()) {
20882091
mlir::Location loc = cgf.getLoc(e->getExprLoc());
20892092
mlir::Value complex = cgf.emitComplexExpr(op);
2090-
return cgf.builder.createComplexImag(loc, complex);
2093+
if (!promotionTy.isNull()) {
2094+
complex = cgf.emitPromotedValue(complex, promotionTy);
2095+
}
2096+
2097+
return e->getOpcode() == clang::UO_Real
2098+
? builder.createComplexReal(loc, complex)
2099+
: builder.createComplexImag(loc, complex);
20912100
}
20922101

20932102
// Otherwise, calculate and project.
20942103
cgf.cgm.errorNYI(e->getSourceRange(),
2095-
"VisitUnaryImag calculate and project");
2104+
"VisitRealImag calculate and project");
2105+
return {};
20962106
}
20972107

2098-
return Visit(op);
2108+
// __real on a scalar yields the operand itself; only __imag on a scalar
2109+
// yields zero, and its subexpr must still be emitted for side effects.
2110+
cgf.cgm.errorNYI(e->getSourceRange(),
2111+
"VisitRealImag __real or __imag on a scalar");
2112+
return {};
20992113
}
21002114

21012115
/// Return the size or alignment of the type of argument of the sizeof

clang/test/CIR/CodeGen/complex.cpp

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -927,3 +927,77 @@ void foo34() {
927927
// OGCG: %[[A_IMAG_PTR:.*]] = getelementptr inbounds nuw { float, float }, ptr %[[A_ADDR]], i32 0, i32 1
928928
// OGCG: store float 1.000000e+00, ptr %[[A_REAL_PTR]], align 8
929929
// OGCG: store float 2.000000e+00, ptr %[[A_IMAG_PTR]], align 4
930+
931+
void foo35() {
932+
_Float16 _Complex a;
933+
_Float16 real = __real__ a;
934+
}
935+
936+
// CIR: %[[A_ADDR:.*]] = cir.alloca !cir.complex<!cir.f16>, !cir.ptr<!cir.complex<!cir.f16>>, ["a"]
937+
// CIR: %[[REAL_ADDR:.*]] = cir.alloca !cir.f16, !cir.ptr<!cir.f16>, ["real", init]
938+
// CIR: %[[TMP_A:.*]] = cir.load{{.*}} %[[A_ADDR]] : !cir.ptr<!cir.complex<!cir.f16>>, !cir.complex<!cir.f16>
939+
// CIR: %[[A_REAL:.*]] = cir.complex.real %[[TMP_A]] : !cir.complex<!cir.f16> -> !cir.f16
940+
// CIR: %[[A_IMAG:.*]] = cir.complex.imag %[[TMP_A]] : !cir.complex<!cir.f16> -> !cir.f16
941+
// CIR: %[[A_REAL_F32:.*]] = cir.cast(floating, %[[A_REAL]] : !cir.f16), !cir.float
942+
// CIR: %[[A_IMAG_F32:.*]] = cir.cast(floating, %[[A_IMAG]] : !cir.f16), !cir.float
943+
// CIR: %[[A_COMPLEX_F32:.*]] = cir.complex.create %[[A_REAL_F32]], %[[A_IMAG_F32]] : !cir.float -> !cir.complex<!cir.float>
944+
// CIR: %[[A_REAL_F32:.*]] = cir.complex.real %[[A_COMPLEX_F32]] : !cir.complex<!cir.float> -> !cir.float
945+
// CIR: %[[A_REAL_F16:.*]] = cir.cast(floating, %[[A_REAL_F32]] : !cir.float), !cir.f16
946+
// CIR: cir.store{{.*}} %[[A_REAL_F16]], %[[REAL_ADDR]] : !cir.f16, !cir.ptr<!cir.f16>
947+
948+
// LLVM: %[[A_ADDR:.*]] = alloca { half, half }, i64 1, align 2
949+
// LLVM: %[[REAL_ADDR:.*]] = alloca half, i64 1, align 2
950+
// LLVM: %[[TMP_A:.*]] = load { half, half }, ptr %[[A_ADDR]], align 2
951+
// LLVM: %[[A_REAL:.*]] = extractvalue { half, half } %[[TMP_A]], 0
952+
// LLVM: %[[A_IMAG:.*]] = extractvalue { half, half } %[[TMP_A]], 1
953+
// LLVM: %[[A_REAL_F32:.*]] = fpext half %[[A_REAL]] to float
954+
// LLVM: %[[A_IMAG_F32:.*]] = fpext half %[[A_IMAG]] to float
955+
// LLVM: %[[TMP_A_COMPLEX_F32:.*]] = insertvalue { float, float } {{.*}}, float %[[A_REAL_F32]], 0
956+
// LLVM: %[[A_COMPLEX_F32:.*]] = insertvalue { float, float } %[[TMP_A_COMPLEX_F32]], float %[[A_IMAG_F32]], 1
957+
// LLVM: %[[A_REAL_F16:.*]] = fptrunc float %[[A_REAL_F32]] to half
958+
// LLVM: store half %[[A_REAL_F16]], ptr %[[REAL_ADDR]], align 2
959+
960+
// OGCG: %[[A_ADDR:.*]] = alloca { half, half }, align 2
961+
// OGCG: %[[REAL_ADDR:.*]] = alloca half, align 2
962+
// OGCG: %[[A_REAL_PTR:.*]] = getelementptr inbounds nuw { half, half }, ptr %[[A_ADDR]], i32 0, i32 0
963+
// OGCG: %[[A_REAL:.*]] = load half, ptr %[[A_REAL_PTR]], align 2
964+
// OGCG: %[[A_REAL_F32:.*]] = fpext half %[[A_REAL]] to float
965+
// OGCG: %[[A_REAL_F16:.*]] = fptrunc float %[[A_REAL_F32]] to half
966+
// OGCG: store half %[[A_REAL_F16]], ptr %[[REAL_ADDR]], align 2
967+
968+
void foo36() {
969+
_Float16 _Complex a;
970+
_Float16 imag = __imag__ a;
971+
}
972+
973+
// CIR: %[[A_ADDR:.*]] = cir.alloca !cir.complex<!cir.f16>, !cir.ptr<!cir.complex<!cir.f16>>, ["a"]
974+
// CIR: %[[IMAG_ADDR:.*]] = cir.alloca !cir.f16, !cir.ptr<!cir.f16>, ["imag", init]
975+
// CIR: %[[TMP_A:.*]] = cir.load{{.*}} %[[A_ADDR]] : !cir.ptr<!cir.complex<!cir.f16>>, !cir.complex<!cir.f16>
976+
// CIR: %[[A_REAL:.*]] = cir.complex.real %[[TMP_A]] : !cir.complex<!cir.f16> -> !cir.f16
977+
// CIR: %[[A_IMAG:.*]] = cir.complex.imag %[[TMP_A]] : !cir.complex<!cir.f16> -> !cir.f16
978+
// CIR: %[[A_REAL_F32:.*]] = cir.cast(floating, %[[A_REAL]] : !cir.f16), !cir.float
979+
// CIR: %[[A_IMAG_F32:.*]] = cir.cast(floating, %[[A_IMAG]] : !cir.f16), !cir.float
980+
// CIR: %[[A_COMPLEX_F32:.*]] = cir.complex.create %[[A_REAL_F32]], %[[A_IMAG_F32]] : !cir.float -> !cir.complex<!cir.float>
981+
// CIR: %[[A_IMAG_F32:.*]] = cir.complex.imag %[[A_COMPLEX_F32]] : !cir.complex<!cir.float> -> !cir.float
982+
// CIR: %[[A_IMAG_F16:.*]] = cir.cast(floating, %[[A_IMAG_F32]] : !cir.float), !cir.f16
983+
// CIR: cir.store{{.*}} %[[A_IMAG_F16]], %[[IMAG_ADDR]] : !cir.f16, !cir.ptr<!cir.f16>
984+
985+
// LLVM: %[[A_ADDR:.*]] = alloca { half, half }, i64 1, align 2
986+
// LLVM: %[[IMAG_ADDR:.*]] = alloca half, i64 1, align 2
987+
// LLVM: %[[TMP_A:.*]] = load { half, half }, ptr %[[A_ADDR]], align 2
988+
// LLVM: %[[A_REAL:.*]] = extractvalue { half, half } %[[TMP_A]], 0
989+
// LLVM: %[[A_IMAG:.*]] = extractvalue { half, half } %[[TMP_A]], 1
990+
// LLVM: %[[A_REAL_F32:.*]] = fpext half %[[A_REAL]] to float
991+
// LLVM: %[[A_IMAG_F32:.*]] = fpext half %[[A_IMAG]] to float
992+
// LLVM: %[[TMP_A_COMPLEX_F32:.*]] = insertvalue { float, float } {{.*}}, float %[[A_REAL_F32]], 0
993+
// LLVM: %[[A_COMPLEX_F32:.*]] = insertvalue { float, float } %[[TMP_A_COMPLEX_F32]], float %[[A_IMAG_F32]], 1
994+
// LLVM: %[[A_IMAG_F16:.*]] = fptrunc float %[[A_IMAG_F32]] to half
995+
// LLVM: store half %[[A_IMAG_F16]], ptr %[[IMAG_ADDR]], align 2
996+
997+
// OGCG: %[[A_ADDR:.*]] = alloca { half, half }, align 2
998+
// OGCG: %[[IMAG_ADDR:.*]] = alloca half, align 2
999+
// OGCG: %[[A_IMAG_PTR:.*]] = getelementptr inbounds nuw { half, half }, ptr %[[A_ADDR]], i32 0, i32 1
1000+
// OGCG: %[[A_IMAG:.*]] = load half, ptr %[[A_IMAG_PTR]], align 2
1001+
// OGCG: %[[A_IMAG_F32:.*]] = fpext half %[[A_IMAG]] to float
1002+
// OGCG: %[[A_IMAG_F16:.*]] = fptrunc float %[[A_IMAG_F32]] to half
1003+
// OGCG: store half %[[A_IMAG_F16]], ptr %[[IMAG_ADDR]], align 2

0 commit comments

Comments
 (0)