Skip to content

Commit a4a64fb

Browse files
authored
[CIR] Backport fold Real & Imag Ops with ComplexCreateOp operand (#1847)
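Backport from upstream the folding of cir.complex.real and cir.complex.imag when their operand is produced by a cir.complex.create op: the fold returns the corresponding operand of the create op (operand 0 for real, operand 1 for imag).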
1 parent 90eef9f commit a4a64fb

File tree

4 files changed

+46
-27
lines changed

4 files changed

+46
-27
lines changed

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -946,10 +946,11 @@ LogicalResult cir::ComplexRealOp::verify() {
946946
}
947947

948948
OpFoldResult cir::ComplexRealOp::fold(FoldAdaptor adaptor) {
949-
auto input = mlir::cast_if_present<cir::ComplexAttr>(adaptor.getOperand());
950-
if (input)
951-
return input.getReal();
952-
return nullptr;
949+
if (auto complexCreateOp = getOperand().getDefiningOp<cir::ComplexCreateOp>())
950+
return complexCreateOp.getOperand(0);
951+
952+
auto complex = mlir::cast_if_present<cir::ComplexAttr>(adaptor.getOperand());
953+
return complex ? complex.getReal() : nullptr;
953954
}
954955

955956
LogicalResult cir::ComplexImagOp::verify() {
@@ -961,10 +962,11 @@ LogicalResult cir::ComplexImagOp::verify() {
961962
}
962963

963964
OpFoldResult cir::ComplexImagOp::fold(FoldAdaptor adaptor) {
964-
auto input = mlir::cast_if_present<cir::ComplexAttr>(adaptor.getOperand());
965-
if (input)
966-
return input.getImag();
967-
return nullptr;
965+
if (auto complexCreateOp = getOperand().getDefiningOp<cir::ComplexCreateOp>())
966+
return complexCreateOp.getOperand(1);
967+
968+
auto complex = mlir::cast_if_present<cir::ComplexAttr>(adaptor.getOperand());
969+
return complex ? complex.getImag() : nullptr;
968970
}
969971

970972
//===----------------------------------------------------------------------===//

clang/test/CIR/CodeGen/complex-arithmetic.c

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -581,7 +581,6 @@ void add_assign() {
581581

582582
// CHECK: }
583583

584-
585584
void add_assign_float16() {
586585
_Float16 _Complex a;
587586
_Float16 _Complex b;
@@ -606,7 +605,7 @@ void add_assign_float16() {
606605
// CIR: %[[A_IMAG:.*]] = cir.complex.imag %[[TMP_A]] : !cir.complex<!cir.f16> -> !cir.f16
607606
// CIR: %[[A_REAL_F32:.*]] = cir.cast(floating, %[[A_REAL]] : !cir.f16), !cir.float
608607
// CIR: %[[A_IMAG_F32:.*]] = cir.cast(floating, %[[A_IMAG]] : !cir.f16), !cir.float
609-
// CIR: %[[A_F32_COMPLEX:.*]] = cir.complex.create %[[A_REAL_F32]], %[[A_IMAG_F32]] : !cir.float -> !cir.complex<!cir.float>
608+
// CIR: %[[A_F32_COMPLEX:.*]] = cir.complex.create %[[A_REAL_F32]], %[[A_IMAG_F32]] : !cir.float -> !cir.complex<!cir.float>
610609
// CIR: %[[A_F32_REAL:.*]] = cir.complex.real %[[A_F32_COMPLEX]] : !cir.complex<!cir.float> -> !cir.float
611610
// CIR: %[[A_F32_IMAG:.*]] = cir.complex.imag %[[A_F32_COMPLEX]] : !cir.complex<!cir.float> -> !cir.float
612611
// CIR: %[[B_F32_REAL:.*]] = cir.complex.real %[[B_F32_COMPLEX]] : !cir.complex<!cir.float> -> !cir.float
@@ -631,14 +630,15 @@ void add_assign_float16() {
631630
// LLVM: %[[A_IMAG_F32:.*]] = fpext half %[[A_IMAG]] to float
632631
// LLVM: %[[TMP_A_COMPLEX_F32:.*]] = insertvalue { float, float } {{.*}}, float %[[A_REAL_F32]], 0
633632
// LLVM: %[[A_COMPLEX_F32:.*]] = insertvalue { float, float } %[[TMP_A_COMPLEX_F32]], float %[[A_IMAG_F32]], 1
634-
// LLVM: %[[A_F32_REAL:.*]] = extractvalue { float, float } %[[A_COMPLEX_F32]], 0
635-
// LLVM: %[[A_F32_IMAG:.*]] = extractvalue { float, float } %[[A_COMPLEX_F32]], 1
636-
// LLVM: %[[B_F32_REAL:.*]] = extractvalue { float, float } %[[B_COMPLEX_F32]], 0
637-
// LLVM: %[[B_F32_IMAG:.*]] = extractvalue { float, float } %[[B_COMPLEX_F32]], 1
638-
// LLVM: %[[ADD_REAL:.*]] = fadd float %[[A_F32_REAL]], %[[B_F32_REAL]]
639-
// LLVM: %[[ADD_IMAG:.*]] = fadd float %[[A_F32_IMAG]], %[[B_F32_IMAG]]
640-
// LLVM: %[[TMP_RESULT:.*]] = insertvalue { float, float } {{.*}}, float %[[ADD_REAL]], 0
641-
// LLVM: %[[RESULT:.*]] = insertvalue { float, float } %[[TMP_RESULT]], float %[[ADD_IMAG]], 1
633+
// LLVM: %[[RESULT_REAL_F32:.*]] = fadd float %[[A_REAL_F32]], %[[B_REAL_F32]]
634+
// LLVM: %[[RESULT_IMAG_F32:.*]] = fadd float %[[A_IMAG_F32]], %[[B_IMAG_F32]]
635+
// LLVM: %[[TMP_RESULT:.*]] = insertvalue { float, float } {{.*}}, float %[[RESULT_REAL_F32]], 0
636+
// LLVM: %[[RESULT:.*]] = insertvalue { float, float } %[[TMP_RESULT]], float %[[RESULT_IMAG_F32]], 1
637+
// LLVM: %[[RESULT_REAL_F16:.*]] = fptrunc float %[[RESULT_REAL_F32]] to half
638+
// LLVM: %[[RESULT_IMAG_F16:.*]] = fptrunc float %[[RESULT_IMAG_F32]] to half
639+
// LLVM: %[[TMP_RESULT_F16:.*]] = insertvalue { half, half } {{.*}}, half %[[RESULT_REAL_F16]], 0
640+
// LLVM: %[[RESULT_F16:.*]] = insertvalue { half, half } %[[TMP_RESULT_F16]], half %[[RESULT_IMAG_F16]], 1
641+
// LLVM: store { half, half } %[[RESULT_F16]], ptr %[[A_ADDR]], align 2
642642

643643
// CHECK: }
644644

clang/test/CIR/CodeGen/complex-compound-assignment.cpp

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,17 +35,14 @@ void foo() {
3535
// C_LLVM: %[[B_ADDR:.*]] = alloca float, i64 1, align 4
3636
// C_LLVM: %[[TMP_A:.*]] = load { float, float }, ptr %[[A_ADDR]], align 4
3737
// C_LLVM: %[[TMP_B:.*]] = load float, ptr %[[B_ADDR]], align 4
38-
// C_LLVM: %[[TMP_B_COMPLEX:.*]] = insertvalue { float, float } {{.*}}, float %[[TMP_B]], 0
39-
// C_LLVM: %[[B_COMPLEX:.*]] = insertvalue { float, float } %[[TMP_B_COMPLEX]], float 0.000000e+00, 1
40-
// C_LLVM: %[[B_REAL:.*]] = extractvalue { float, float } %[[B_COMPLEX]], 0
41-
// C_LLVM: %[[B_IMAG:.*]] = extractvalue { float, float } %[[B_COMPLEX]], 1
38+
// C_LLVM: %[[TMP_COMPLEX_B:.*]] = insertvalue { float, float } {{.*}}, float %[[TMP_B]], 0
39+
// C_LLVM: %[[COMPLEX_B:.*]] = insertvalue { float, float } %[[TMP_COMPLEX_B]], float 0.000000e+00, 1
4240
// C_LLVM: %[[A_REAL:.*]] = extractvalue { float, float } %[[TMP_A]], 0
4341
// C_LLVM: %[[A_IMAG:.*]] = extractvalue { float, float } %[[TMP_A]], 1
44-
// C_LLVM: %[[ADD_REAL:.*]] = fadd float %[[B_REAL]], %[[A_REAL]]
45-
// C_LLVM: %[[ADD_IMAG:.*]] = fadd float %[[B_IMAG]], %[[A_IMAG]]
46-
// C_LLVM: %[[TMP_RESULT_COMPLEX:.*]] = insertvalue { float, float } {{.*}}, float %[[ADD_REAL]], 0
47-
// C_LLVM: %[[RESULT_COMPLEX:.*]] = insertvalue { float, float } %[[TMP_RESULT_COMPLEX]], float %[[ADD_IMAG]], 1
48-
// C_LLVM: %[[RESULT_REAL:.*]] = extractvalue { float, float } %[[RESULT_COMPLEX]], 0
42+
// C_LLVM: %[[RESULT_REAL:.*]] = fadd float %[[TMP_B]], %[[A_REAL]]
43+
// C_LLVM: %[[RESULT_IMAG:.*]] = fadd float 0.000000e+00, %[[A_IMAG]]
44+
// C_LLVM: %[[TMP_RESULT:.*]] = insertvalue { float, float } {{.*}}, float %[[RESULT_REAL]], 0
45+
// C_LLVM: %[[RESULT:.*]] = insertvalue { float, float } %[[TMP_RESULT]], float %[[RESULT_IMAG]], 1
4946
// C_LLVM: store float %[[RESULT_REAL]], ptr %[[B_ADDR]], align 4
5047

5148
// C_OGCG: %[[A_ADDR:.*]] = alloca { float, float }, align 4

clang/test/CIR/Transforms/complex-fold.cir

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,16 @@ module {
2929
// CHECK-NEXT: cir.return %[[#A]] : !s32i
3030
// CHECK-NEXT: }
3131

32+
cir.func @fold_complex_real_from_create_test(%arg0: !s32i, %arg1: !s32i) -> !s32i {
33+
%0 = cir.complex.create %arg0, %arg1 : !s32i -> !cir.complex<!s32i>
34+
%1 = cir.complex.real %0 : !cir.complex<!s32i> -> !s32i
35+
cir.return %1 : !s32i
36+
}
37+
38+
// CHECK: cir.func @fold_complex_real_from_create_test(%[[ARG_0:.*]]: !s32i, %[[ARG_1:.*]]: !s32i) -> !s32i {
39+
// CHECK-NEXT: cir.return %[[ARG_0]] : !s32i
40+
// CHECK-NEXT: }
41+
3242
cir.func @fold_complex_imag() -> !s32i {
3343
%0 = cir.const #cir.int<1> : !s32i
3444
%1 = cir.const #cir.int<2> : !s32i
@@ -41,4 +51,14 @@ module {
4151
// CHECK-NEXT: %[[#A:]] = cir.const #cir.int<2> : !s32i
4252
// CHECK-NEXT: cir.return %[[#A]] : !s32i
4353
// CHECK-NEXT: }
54+
55+
cir.func @fold_complex_imag_from_create_test(%arg0: !s32i, %arg1: !s32i) -> !s32i {
56+
%0 = cir.complex.create %arg0, %arg1 : !s32i -> !cir.complex<!s32i>
57+
%1 = cir.complex.imag %0 : !cir.complex<!s32i> -> !s32i
58+
cir.return %1 : !s32i
59+
}
60+
61+
// CHECK: cir.func @fold_complex_imag_from_create_test(%[[ARG_0:.*]]: !s32i, %[[ARG_1:.*]]: !s32i) -> !s32i {
62+
// CHECK-NEXT: cir.return %[[ARG_1]] : !s32i
63+
// CHECK-NEXT: }
4464
}

0 commit comments

Comments
 (0)