Skip to content

Commit d91ac48

Browse files
committed
[CIR] Implement AddOp for ComplexType
1 parent e29ac9b commit d91ac48

File tree

6 files changed

+285
-0
lines changed

6 files changed

+285
-0
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2521,6 +2521,32 @@ def ComplexImagOp : CIR_Op<"complex.imag", [Pure]> {
25212521
let hasFolder = 1;
25222522
}
25232523

2524+
//===----------------------------------------------------------------------===//
2525+
// ComplexAddOp
2526+
//===----------------------------------------------------------------------===//
2527+
2528+
// Declares the `cir.complex.add` operation. It takes two CIR_ComplexType
// operands (`lhs`, `rhs`) and yields one CIR_ComplexType result. The op is
// declared with the Pure and SameOperandsAndResultType traits, and its textual
// form is `$lhs, $rhs -> <result type>` followed by the attribute dictionary.
def ComplexAddOp : CIR_Op<"complex.add", [Pure, SameOperandsAndResultType]> {
2529+
let summary = "Complex addition";
2530+
let description = [{
2531+
The `cir.complex.add` operation takes two complex numbers and returns
2532+
their sum.
2533+
2534+
Example:
2535+
2536+
```mlir
2537+
%2 = cir.complex.add %0, %1 -> !cir.complex<!cir.float>
2538+
```
2539+
}];
2540+
2541+
let arguments = (ins CIR_ComplexType:$lhs, CIR_ComplexType:$rhs);
2542+
2543+
let results = (outs CIR_ComplexType:$result);
2544+
2545+
let assemblyFormat = [{
2546+
$lhs `,` $rhs `->` qualified(type($result)) attr-dict
2547+
}];
2548+
}
2549+
25242550
//===----------------------------------------------------------------------===//
25252551
// Bit Manipulation Operations
25262552
//===----------------------------------------------------------------------===//

clang/lib/CIR/CodeGen/CIRGenExprComplex.cpp

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,55 @@ class ComplexExprEmitter : public StmtVisitor<ComplexExprEmitter, mlir::Value> {
5757
mlir::Value
5858
VisitSubstNonTypeTemplateParmExpr(SubstNonTypeTemplateParmExpr *e);
5959
mlir::Value VisitUnaryDeref(const Expr *e);
60+
61+
/// Bundles everything needed to emit one complex binary operation. An
/// instance is populated by emitBinOps() and consumed by the emitBin<OP>
/// helpers (currently emitBinAdd()).
struct BinOpInfo {
62+
// Location attached to the operation that gets created.
mlir::Location loc;
63+
// Emitted left-hand operand.
mlir::Value lhs{};
64+
// Emitted right-hand operand.
mlir::Value rhs{};
65+
// emitBinOps() stores the promotion type here when one is supplied, and the
// expression's own type otherwise.
QualType ty{}; // Computation Type.
66+
// Floating-point options in effect for the expression, as reported by
// getFPFeaturesInEffect().
FPOptions fpFeatures{};
67+
};
68+
69+
/// Emits both operands of `e` and returns them together with the source
/// location, computation type and FP options. `promotionTy` defaults to a
/// null QualType, meaning no promotion is requested.
BinOpInfo emitBinOps(const BinaryOperator *e,
70+
QualType promotionTy = QualType());
71+
72+
/// Emits `e` for use with the promotion type `promotionTy`; this is the
/// entry point used by CIRGenFunction::emitPromotedComplexExpr().
mlir::Value emitPromoted(const Expr *e, QualType promotionTy);
73+
74+
/// Emits a single operand of a complex binary operator. Only operands of
/// complex type are handled; other operand types are reported as NYI.
mlir::Value emitPromotedComplexOperand(const Expr *e, QualType promotionTy);
75+
76+
/// Creates the cir::ComplexAddOp for the operands collected in `op`.
mlir::Value emitBinAdd(const BinOpInfo &op);
77+
78+
/// Picks the type an operation on `ty` should be carried out in.
///
/// \param ty           Type of the expression being emitted.
/// \param isDivOpCode  True when the operator is a division; only consulted
///                     for complex types.
/// \returns the promoted type, or a null QualType when no promotion applies
///          (or when the required promotion is not implemented yet).
QualType getPromotionType(QualType ty, bool isDivOpCode = false) {
79+
if (auto *complexTy = ty->getAs<ComplexType>()) {
80+
QualType elementTy = complexTy->getElementType();
81+
// Division of floating-point complex values under the CX_Promoted range
// setting is not implemented: report it and request no promotion.
if (isDivOpCode && elementTy->isFloatingType() &&
82+
cgf.getLangOpts().getComplexRange() ==
83+
LangOptions::ComplexRangeKind::CX_Promoted) {
84+
cgf.cgm.errorNYI("HigherPrecisionTypeForComplexArithmetic");
85+
return QualType();
86+
}
87+
88+
// When the element type reports UseExcessPrecision, compute in complex
// float.
if (elementTy.UseExcessPrecision(cgf.getContext()))
89+
return cgf.getContext().getComplexType(cgf.getContext().FloatTy);
90+
}
91+
92+
// Reached for non-complex types, and for complex types whose element type
// did not trigger a promotion above.
if (ty.UseExcessPrecision(cgf.getContext()))
93+
return cgf.getContext().FloatTy;
94+
return QualType();
95+
}
96+
97+
// HANDLEBINOP(OP) defines the visitor method VisitBin<OP>. The generated
// method asks getPromotionType() for a promotion type (telling it whether the
// opcode is BO_Div), gathers the operands with emitBinOps(), and forwards them
// to emitBin<OP>(). When a promotion type was used, converting the result back
// is not implemented: it is reported through errorNYI and the result is
// returned as is. Only `Add` is instantiated here, and the macro is undefined
// straight afterwards.
#define HANDLEBINOP(OP)                                                        \
98+
mlir::Value VisitBin##OP(const BinaryOperator *e) {                          \
99+
QualType promotionTy = getPromotionType(                                   \
100+
e->getType(), e->getOpcode() == BinaryOperatorKind::BO_Div);           \
101+
mlir::Value result = emitBin##OP(emitBinOps(e, promotionTy));              \
102+
if (!promotionTy.isNull())                                                 \
103+
cgf.cgm.errorNYI("Binop emitUnPromotedValue");                           \
104+
return result;                                                             \
105+
}
106+
107+
HANDLEBINOP(Add)
108+
#undef HANDLEBINOP
60109
};
61110
} // namespace
62111

@@ -291,6 +340,58 @@ mlir::Value ComplexExprEmitter::VisitUnaryDeref(const Expr *e) {
291340
return emitLoadOfLValue(e);
292341
}
293342

343+
/// Emits the complex expression `e` for use with the promotion type
/// `promotionTy`.
///
/// Parentheses are stripped first. A binary operator whose opcode has an
/// emitBin<OP> handler (only Add at present) is emitted directly with the
/// promotion type passed down to its operands. Unary operators are reported
/// as NYI and yield a null value. Every other expression, including binary
/// operators with other opcodes, goes through the regular visitor.
mlir::Value ComplexExprEmitter::emitPromoted(const Expr *e,
344+
QualType promotionTy) {
345+
e = e->IgnoreParens();
346+
if (const auto *bo = dyn_cast<BinaryOperator>(e)) {
347+
// One `case` per supported opcode is generated by HANDLE_BINOP below.
switch (bo->getOpcode()) {
348+
#define HANDLE_BINOP(OP)                                                       \
349+
case BO_##OP:                                                                \
350+
return emitBin##OP(emitBinOps(bo, promotionTy));
351+
HANDLE_BINOP(Add)
352+
#undef HANDLE_BINOP
353+
default:
354+
break;
355+
}
356+
} else if (isa<UnaryOperator>(e)) {
357+
cgf.cgm.errorNYI("emitPromoted UnaryOperator");
358+
return {};
359+
}
360+
361+
// Visit() takes a non-const Expr *, hence the const_cast.
mlir::Value result = Visit(const_cast<Expr *>(e));
362+
// Converting the visited value to the promotion type is not implemented
// yet; the unconverted value is still returned after the NYI report.
if (!promotionTy.isNull())
363+
cgf.cgm.errorNYI("emitPromoted emitPromotedValue");
364+
365+
return result;
366+
}
367+
368+
/// Emits one operand of a complex binary operator.
///
/// A complex-typed operand is emitted through
/// CIRGenFunction::emitPromotedComplexExpr() when `promotionTy` is non-null,
/// and through the regular visitor otherwise. Operands of any other type are
/// not supported yet: an NYI error is reported and a null value is returned.
mlir::Value
369+
ComplexExprEmitter::emitPromotedComplexOperand(const Expr *e,
370+
QualType promotionTy) {
371+
if (e->getType()->isAnyComplexType()) {
372+
if (!promotionTy.isNull())
373+
return cgf.emitPromotedComplexExpr(e, promotionTy);
374+
return Visit(const_cast<Expr *>(e));
375+
}
376+
377+
cgf.cgm.errorNYI("emitPromotedComplexOperand non-complex type");
378+
return {};
379+
}
380+
381+
/// Builds the BinOpInfo for the binary operator `e`: emits the LHS and then
/// the RHS via emitPromotedComplexOperand(), and records the expression
/// location, the computation type and the FP options in effect.
ComplexExprEmitter::BinOpInfo
382+
ComplexExprEmitter::emitBinOps(const BinaryOperator *e, QualType promotionTy) {
383+
BinOpInfo binOpInfo{cgf.getLoc(e->getExprLoc())};
384+
binOpInfo.lhs = emitPromotedComplexOperand(e->getLHS(), promotionTy);
385+
binOpInfo.rhs = emitPromotedComplexOperand(e->getRHS(), promotionTy);
386+
// The computation type is the promotion type when one was requested,
// otherwise the type of the expression itself.
binOpInfo.ty = promotionTy.isNull() ? e->getType() : promotionTy;
387+
binOpInfo.fpFeatures = e->getFPFeaturesInEffect(cgf.getLangOpts());
388+
return binOpInfo;
389+
}
390+
391+
/// Creates a cir::ComplexAddOp at `op.loc` with `op.lhs` and `op.rhs` as
/// operands and returns the created op as a value. The `ty` and `fpFeatures`
/// fields of `op` are not consulted here.
mlir::Value ComplexExprEmitter::emitBinAdd(const BinOpInfo &op) {
392+
return builder.create<cir::ComplexAddOp>(op.loc, op.lhs, op.rhs);
393+
}
394+
294395
LValue CIRGenFunction::emitComplexAssignmentLValue(const BinaryOperator *e) {
295396
assert(e->getOpcode() == BO_Assign && "Expected assign op");
296397

@@ -313,3 +414,8 @@ void CIRGenFunction::emitStoreOfComplex(mlir::Location loc, mlir::Value v,
313414
LValue dest, bool isInit) {
314415
ComplexExprEmitter(*this).emitStoreOfComplex(loc, v, dest, isInit);
315416
}
417+
418+
/// Emits the complex expression `e` with the promotion type `promotionType`
/// by delegating to a temporary ComplexExprEmitter's emitPromoted().
mlir::Value CIRGenFunction::emitPromotedComplexExpr(const Expr *e,
419+
QualType promotionType) {
420+
return ComplexExprEmitter(*this).emitPromoted(e, promotionType);
421+
}

clang/lib/CIR/CodeGen/CIRGenFunction.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -886,6 +886,8 @@ class CIRGenFunction : public CIRGenTypeCache {
886886
void emitInitializerForField(clang::FieldDecl *field, LValue lhs,
887887
clang::Expr *init);
888888

889+
/// Emit the expression `e` of complex type for use with the promotion type
/// `promotionType`, returning the resulting value.
mlir::Value emitPromotedComplexExpr(const Expr *e, QualType promotionType);
890+
889891
mlir::Value emitPromotedScalarExpr(const Expr *e, QualType promotionType);
890892

891893
/// Emit the computation of the specified expression of scalar type.

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2048,6 +2048,7 @@ void ConvertCIRToLLVMPass::runOnOperation() {
20482048
CIRToLLVMBrOpLowering,
20492049
CIRToLLVMCallOpLowering,
20502050
CIRToLLVMCmpOpLowering,
2051+
CIRToLLVMComplexAddOpLowering,
20512052
CIRToLLVMComplexCreateOpLowering,
20522053
CIRToLLVMComplexImagOpLowering,
20532054
CIRToLLVMComplexRealOpLowering,
@@ -2357,6 +2358,54 @@ mlir::LogicalResult CIRToLLVMVecTernaryOpLowering::matchAndRewrite(
23572358
return mlir::success();
23582359
}
23592360

2361+
/// Lowers cir.complex.add to LLVM dialect operations.
///
/// Both converted operands are split into their two fields (index 0 is
/// treated as the real part, index 1 as the imaginary part). The matching
/// parts are added with llvm.add for integer element types and llvm.fadd
/// otherwise, and the two sums are inserted into a fresh aggregate of the
/// converted result type. This pattern always reports success.
mlir::LogicalResult CIRToLLVMComplexAddOpLowering::matchAndRewrite(
2362+
cir::ComplexAddOp op, OpAdaptor adaptor,
2363+
mlir::ConversionPatternRewriter &rewriter) const {
2364+
// The adaptor exposes the operands after type conversion.
mlir::Value lhs = adaptor.getLhs();
2365+
mlir::Value rhs = adaptor.getRhs();
2366+
mlir::Location loc = op.getLoc();
2367+
2368+
// The element type is read from the original (unconverted) operand type and
// then converted.
auto complexType = mlir::cast<cir::ComplexType>(op.getLhs().getType());
2369+
mlir::Type complexElemTy =
2370+
getTypeConverter()->convertType(complexType.getElementType());
2371+
auto lhsReal =
2372+
rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, lhs, 0);
2373+
auto lhsImag =
2374+
rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, lhs, 1);
2375+
auto rhsReal =
2376+
rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, rhs, 0);
2377+
auto rhsImag =
2378+
rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, rhs, 1);
2379+
2380+
mlir::Value newReal;
2381+
mlir::Value newImag;
2382+
// Integer elements use the integer add; every other element type takes the
// floating-point add.
if (complexElemTy.isInteger()) {
2383+
newReal = rewriter.create<mlir::LLVM::AddOp>(loc, complexElemTy, lhsReal,
2384+
rhsReal);
2385+
newImag = rewriter.create<mlir::LLVM::AddOp>(loc, complexElemTy, lhsImag,
2386+
rhsImag);
2387+
} else {
2388+
newReal = rewriter.create<mlir::LLVM::FAddOp>(loc, complexElemTy, lhsReal,
2389+
rhsReal);
2390+
newImag = rewriter.create<mlir::LLVM::FAddOp>(loc, complexElemTy, lhsImag,
2391+
rhsImag);
2392+
}
2393+
2394+
// Rebuild the result: start from an undef aggregate and fill in both
// fields.
mlir::Type complexLLVMTy =
2395+
getTypeConverter()->convertType(op.getResult().getType());
2396+
auto initialComplex =
2397+
rewriter.create<mlir::LLVM::UndefOp>(op->getLoc(), complexLLVMTy);
2398+
2399+
auto realComplex = rewriter.create<mlir::LLVM::InsertValueOp>(
2400+
op->getLoc(), initialComplex, newReal, 0);
2401+
2402+
auto complex = rewriter.create<mlir::LLVM::InsertValueOp>(
2403+
op->getLoc(), realComplex, newImag, 1);
2404+
2405+
rewriter.replaceOp(op, complex);
2406+
return mlir::success();
2407+
}
2408+
23602409
mlir::LogicalResult CIRToLLVMComplexCreateOpLowering::matchAndRewrite(
23612410
cir::ComplexCreateOp op, OpAdaptor adaptor,
23622411
mlir::ConversionPatternRewriter &rewriter) const {

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -523,6 +523,16 @@ class CIRToLLVMGetBitfieldOpLowering
523523
mlir::ConversionPatternRewriter &) const override;
524524
};
525525

526+
/// Conversion pattern for cir::ComplexAddOp. It inherits the base class
/// constructors and overrides matchAndRewrite(), which is defined out of
/// line.
class CIRToLLVMComplexAddOpLowering
527+
: public mlir::OpConversionPattern<cir::ComplexAddOp> {
528+
public:
529+
using mlir::OpConversionPattern<cir::ComplexAddOp>::OpConversionPattern;
530+
531+
mlir::LogicalResult
532+
matchAndRewrite(cir::ComplexAddOp op, OpAdaptor,
533+
mlir::ConversionPatternRewriter &) const override;
534+
};
535+
526536
} // namespace direct
527537
} // namespace cir
528538

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
// RUN: %clang_cc1 -std=c++20 -triple x86_64-unknown-linux-gnu -Wno-unused-value -fclangir -emit-cir %s -o %t.cir
2+
// RUN: FileCheck --input-file=%t.cir %s -check-prefix=CIR
3+
// RUN: %clang_cc1 -std=c++20 -triple x86_64-unknown-linux-gnu -Wno-unused-value -fclangir -emit-llvm %s -o %t-cir.ll
4+
// RUN: FileCheck --input-file=%t-cir.ll %s -check-prefix=LLVM
5+
// RUN: %clang_cc1 -std=c++20 -triple x86_64-unknown-linux-gnu -Wno-unused-value -emit-llvm %s -o %t.ll
6+
// RUN: FileCheck --input-file=%t.ll %s -check-prefix=OGCG
7+
8+
// Test input with integer elements. The variables are deliberately left
// uninitialized; only the code generated for `a + b` is checked.
void foo() {
9+
int _Complex a;
10+
int _Complex b;
11+
int _Complex c = a + b;
12+
}
13+
14+
// CIR: %[[COMPLEX_A:.*]] = cir.alloca !cir.complex<!s32i>, !cir.ptr<!cir.complex<!s32i>>, ["a"]
15+
// CIR: %[[COMPLEX_B:.*]] = cir.alloca !cir.complex<!s32i>, !cir.ptr<!cir.complex<!s32i>>, ["b"]
16+
// CIR: %[[TMP_A:.*]] = cir.load{{.*}} %[[COMPLEX_A]] : !cir.ptr<!cir.complex<!s32i>>, !cir.complex<!s32i>
17+
// CIR: %[[TMP_B:.*]] = cir.load{{.*}} %[[COMPLEX_B]] : !cir.ptr<!cir.complex<!s32i>>, !cir.complex<!s32i>
18+
// CIR: %[[ADD:.*]] = cir.complex.add %[[TMP_A]], %[[TMP_B]] -> !cir.complex<!s32i>
19+
20+
// LLVM: %[[COMPLEX_A:.*]] = alloca { i32, i32 }, i64 1, align 4
21+
// LLVM: %[[COMPLEX_B:.*]] = alloca { i32, i32 }, i64 1, align 4
22+
// LLVM: %[[TMP_A:.*]] = load { i32, i32 }, ptr %[[COMPLEX_A]], align 4
23+
// LLVM: %[[TMP_B:.*]] = load { i32, i32 }, ptr %[[COMPLEX_B]], align 4
24+
// LLVM: %[[A_REAL:.*]] = extractvalue { i32, i32 } %[[TMP_A]], 0
25+
// LLVM: %[[A_IMAG:.*]] = extractvalue { i32, i32 } %[[TMP_A]], 1
26+
// LLVM: %[[B_REAL:.*]] = extractvalue { i32, i32 } %[[TMP_B]], 0
27+
// LLVM: %[[B_IMAG:.*]] = extractvalue { i32, i32 } %[[TMP_B]], 1
28+
// LLVM: %[[ADD_REAL:.*]] = add i32 %[[A_REAL]], %[[B_REAL]]
29+
// LLVM: %[[ADD_IMAG:.*]] = add i32 %[[A_IMAG]], %[[B_IMAG]]
30+
// LLVM: %[[RESULT:.*]] = insertvalue { i32, i32 } undef, i32 %[[ADD_REAL]], 0
31+
// LLVM: %[[RESULT_2:.*]] = insertvalue { i32, i32 } %[[RESULT]], i32 %[[ADD_IMAG]], 1
32+
33+
// OGCG: %[[COMPLEX_A:.*]] = alloca { i32, i32 }, align 4
34+
// OGCG: %[[COMPLEX_B:.*]] = alloca { i32, i32 }, align 4
35+
// OGCG: %[[RESULT:.*]] = alloca { i32, i32 }, align 4
36+
// OGCG: %[[A_REAL_PTR:.*]] = getelementptr inbounds nuw { i32, i32 }, ptr %[[COMPLEX_A]], i32 0, i32 0
37+
// OGCG: %[[A_REAL:.*]] = load i32, ptr %[[A_REAL_PTR]], align 4
38+
// OGCG: %[[A_IMAG_PTR:.*]] = getelementptr inbounds nuw { i32, i32 }, ptr %[[COMPLEX_A]], i32 0, i32 1
39+
// OGCG: %[[A_IMAG:.*]] = load i32, ptr %[[A_IMAG_PTR]], align 4
40+
// OGCG: %[[B_REAL_PTR:.*]] = getelementptr inbounds nuw { i32, i32 }, ptr %[[COMPLEX_B]], i32 0, i32 0
41+
// OGCG: %[[B_REAL:.*]] = load i32, ptr %[[B_REAL_PTR]], align 4
42+
// OGCG: %[[B_IMAG_PTR:.*]] = getelementptr inbounds nuw { i32, i32 }, ptr %[[COMPLEX_B]], i32 0, i32 1
43+
// OGCG: %[[B_IMAG:.*]] = load i32, ptr %[[B_IMAG_PTR]], align 4
44+
// OGCG: %[[ADD_REAL:.*]] = add i32 %[[A_REAL]], %[[B_REAL]]
45+
// OGCG: %[[ADD_IMAG:.*]] = add i32 %[[A_IMAG]], %[[B_IMAG]]
46+
// OGCG: %[[RESULT_REAL_PTR:.*]] = getelementptr inbounds nuw { i32, i32 }, ptr %[[RESULT]], i32 0, i32 0
47+
// OGCG: %[[RESULT_IMAG_PTR:.*]] = getelementptr inbounds nuw { i32, i32 }, ptr %[[RESULT]], i32 0, i32 1
48+
// OGCG: store i32 %[[ADD_REAL]], ptr %[[RESULT_REAL_PTR]], align 4
49+
// OGCG: store i32 %[[ADD_IMAG]], ptr %[[RESULT_IMAG_PTR]], align 4
50+
51+
// Test input with float elements. The variables are deliberately left
// uninitialized; only the code generated for `a + b` is checked.
void foo2() {
52+
float _Complex a;
53+
float _Complex b;
54+
float _Complex c = a + b;
55+
}
56+
57+
// CIR: %[[COMPLEX_A:.*]] = cir.alloca !cir.complex<!cir.float>, !cir.ptr<!cir.complex<!cir.float>>, ["a"]
58+
// CIR: %[[COMPLEX_B:.*]] = cir.alloca !cir.complex<!cir.float>, !cir.ptr<!cir.complex<!cir.float>>, ["b"]
59+
// CIR: %[[TMP_A:.*]] = cir.load{{.*}} %[[COMPLEX_A]] : !cir.ptr<!cir.complex<!cir.float>>, !cir.complex<!cir.float>
60+
// CIR: %[[TMP_B:.*]] = cir.load{{.*}} %[[COMPLEX_B]] : !cir.ptr<!cir.complex<!cir.float>>, !cir.complex<!cir.float>
61+
// CIR: %[[ADD:.*]] = cir.complex.add %[[TMP_A]], %[[TMP_B]] -> !cir.complex<!cir.float>
62+
63+
// LLVM: %[[COMPLEX_A:.*]] = alloca { float, float }, i64 1, align 4
64+
// LLVM: %[[COMPLEX_B:.*]] = alloca { float, float }, i64 1, align 4
65+
// LLVM: %[[TMP_A:.*]] = load { float, float }, ptr %[[COMPLEX_A]], align 4
66+
// LLVM: %[[TMP_B:.*]] = load { float, float }, ptr %[[COMPLEX_B]], align 4
67+
// LLVM: %[[A_REAL:.*]] = extractvalue { float, float } %[[TMP_A]], 0
68+
// LLVM: %[[A_IMAG:.*]] = extractvalue { float, float } %[[TMP_A]], 1
69+
// LLVM: %[[B_REAL:.*]] = extractvalue { float, float } %[[TMP_B]], 0
70+
// LLVM: %[[B_IMAG:.*]] = extractvalue { float, float } %[[TMP_B]], 1
71+
// LLVM: %[[ADD_REAL:.*]] = fadd float %[[A_REAL]], %[[B_REAL]]
72+
// LLVM: %[[ADD_IMAG:.*]] = fadd float %[[A_IMAG]], %[[B_IMAG]]
73+
// LLVM: %[[RESULT:.*]] = insertvalue { float, float } undef, float %[[ADD_REAL]], 0
74+
// LLVM: %[[RESULT_2:.*]] = insertvalue { float, float } %[[RESULT]], float %[[ADD_IMAG]], 1
75+
76+
// OGCG: %[[COMPLEX_A:.*]] = alloca { float, float }, align 4
77+
// OGCG: %[[COMPLEX_B:.*]] = alloca { float, float }, align 4
78+
// OGCG: %[[RESULT:.*]] = alloca { float, float }, align 4
79+
// OGCG: %[[A_REAL_PTR:.*]] = getelementptr inbounds nuw { float, float }, ptr %[[COMPLEX_A]], i32 0, i32 0
80+
// OGCG: %[[A_REAL:.*]] = load float, ptr %[[A_REAL_PTR]], align 4
81+
// OGCG: %[[A_IMAG_PTR:.*]] = getelementptr inbounds nuw { float, float }, ptr %[[COMPLEX_A]], i32 0, i32 1
82+
// OGCG: %[[A_IMAG:.*]] = load float, ptr %[[A_IMAG_PTR]], align 4
83+
// OGCG: %[[B_REAL_PTR:.*]] = getelementptr inbounds nuw { float, float }, ptr %[[COMPLEX_B]], i32 0, i32 0
84+
// OGCG: %[[B_REAL:.*]] = load float, ptr %[[B_REAL_PTR]], align 4
85+
// OGCG: %[[B_IMAG_PTR:.*]] = getelementptr inbounds nuw { float, float }, ptr %[[COMPLEX_B]], i32 0, i32 1
86+
// OGCG: %[[B_IMAG:.*]] = load float, ptr %[[B_IMAG_PTR]], align 4
87+
// OGCG: %[[ADD_REAL:.*]] = fadd float %[[A_REAL]], %[[B_REAL]]
88+
// OGCG: %[[ADD_IMAG:.*]] = fadd float %[[A_IMAG]], %[[B_IMAG]]
89+
// OGCG: %[[RESULT_REAL_PTR:.*]] = getelementptr inbounds nuw { float, float }, ptr %[[RESULT]], i32 0, i32 0
90+
// OGCG: %[[RESULT_IMAG_PTR:.*]] = getelementptr inbounds nuw { float, float }, ptr %[[RESULT]], i32 0, i32 1
91+
// OGCG: store float %[[ADD_REAL]], ptr %[[RESULT_REAL_PTR]], align 4
92+
// OGCG: store float %[[ADD_IMAG]], ptr %[[RESULT_IMAG_PTR]], align 4

0 commit comments

Comments
 (0)