Skip to content

Commit 3763712

Browse files
authored
[CIR] Update ComplexRealOp to work on scalar type (#161080)
Update cir::CreateRealOp to make it visible on scalars Issue #160568
1 parent 62c50fd commit 3763712

File tree

7 files changed

+46
-17
lines changed

7 files changed

+46
-17
lines changed

clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -148,9 +148,10 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
148148
}
149149

150150
mlir::Value createComplexReal(mlir::Location loc, mlir::Value operand) {
151-
auto operandTy = mlir::cast<cir::ComplexType>(operand.getType());
152-
return cir::ComplexRealOp::create(*this, loc, operandTy.getElementType(),
153-
operand);
151+
auto resultType = operand.getType();
152+
if (auto complexResultType = mlir::dyn_cast<cir::ComplexType>(resultType))
153+
resultType = complexResultType.getElementType();
154+
return cir::ComplexRealOp::create(*this, loc, resultType, operand);
154155
}
155156

156157
mlir::Value createComplexImag(mlir::Location loc, mlir::Value operand) {

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3260,18 +3260,20 @@ def CIR_ComplexCreateOp : CIR_Op<"complex.create", [Pure, SameTypeOperands]> {
32603260
def CIR_ComplexRealOp : CIR_Op<"complex.real", [Pure]> {
32613261
let summary = "Extract the real part of a complex value";
32623262
let description = [{
3263-
`cir.complex.real` operation takes an operand of `!cir.complex` type and
3264-
yields the real part of it.
3263+
`cir.complex.real` operation takes an operand of `!cir.complex`, `!cir.int`
3264+
or `!cir.float`. If the operand is `!cir.complex`, the real part of it will
3265+
be returned, otherwise the value returned unmodified.
32653266

32663267
Example:
32673268

32683269
```mlir
3269-
%1 = cir.complex.real %0 : !cir.complex<!cir.float> -> !cir.float
3270+
%real = cir.complex.real %complex : !cir.complex<!cir.float> -> !cir.float
3271+
%real = cir.complex.real %scalar : !cir.float -> !cir.float
32703272
```
32713273
}];
32723274

32733275
let results = (outs CIR_AnyIntOrFloatType:$result);
3274-
let arguments = (ins CIR_ComplexType:$operand);
3276+
let arguments = (ins CIR_AnyComplexOrIntOrFloatType:$operand);
32753277

32763278
let assemblyFormat = [{
32773279
$operand `:` qualified(type($operand)) `->` qualified(type($result))

clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,12 @@ def CIR_AnyIntOrFloatType : AnyTypeOf<[CIR_AnyFloatType, CIR_AnyIntType],
165165

166166
def CIR_AnyComplexType : CIR_TypeBase<"::cir::ComplexType", "complex type">;
167167

168+
def CIR_AnyComplexOrIntOrFloatType : AnyTypeOf<[
169+
CIR_AnyComplexType, CIR_AnyFloatType, CIR_AnyIntType
170+
], "complex, integer or floating point type"> {
171+
let cppFunctionName = "isComplexOrIntegerOrFloatingPointType";
172+
}
173+
168174
//===----------------------------------------------------------------------===//
169175
// Array Type predicates
170176
//===----------------------------------------------------------------------===//

clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2151,8 +2151,10 @@ mlir::Value ScalarExprEmitter::VisitRealImag(const UnaryOperator *e,
21512151
}
21522152

21532153
if (e->getOpcode() == UO_Real) {
2154-
return promotionTy.isNull() ? Visit(op)
2155-
: cgf.emitPromotedScalarExpr(op, promotionTy);
2154+
mlir::Value operand = promotionTy.isNull()
2155+
? Visit(op)
2156+
: cgf.emitPromotedScalarExpr(op, promotionTy);
2157+
return builder.createComplexReal(loc, operand);
21562158
}
21572159

21582160
// __imag on a scalar returns zero. Emit the subexpr to ensure side

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2388,14 +2388,23 @@ OpFoldResult cir::ComplexCreateOp::fold(FoldAdaptor adaptor) {
23882388
//===----------------------------------------------------------------------===//
23892389

23902390
LogicalResult cir::ComplexRealOp::verify() {
2391-
if (getType() != getOperand().getType().getElementType()) {
2391+
mlir::Type operandTy = getOperand().getType();
2392+
if (auto complexOperandTy = mlir::dyn_cast<cir::ComplexType>(operandTy)) {
2393+
operandTy = complexOperandTy.getElementType();
2394+
}
2395+
2396+
if (getType() != operandTy) {
23922397
emitOpError() << ": result type does not match operand type";
23932398
return failure();
23942399
}
2400+
23952401
return success();
23962402
}
23972403

23982404
OpFoldResult cir::ComplexRealOp::fold(FoldAdaptor adaptor) {
2405+
if (!mlir::isa<cir::ComplexType>(getOperand().getType()))
2406+
return nullptr;
2407+
23992408
if (auto complexCreateOp = getOperand().getDefiningOp<cir::ComplexCreateOp>())
24002409
return complexCreateOp.getOperand(0);
24012410

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2999,8 +2999,13 @@ mlir::LogicalResult CIRToLLVMComplexRealOpLowering::matchAndRewrite(
29992999
cir::ComplexRealOp op, OpAdaptor adaptor,
30003000
mlir::ConversionPatternRewriter &rewriter) const {
30013001
mlir::Type resultLLVMTy = getTypeConverter()->convertType(op.getType());
3002-
rewriter.replaceOpWithNewOp<mlir::LLVM::ExtractValueOp>(
3003-
op, resultLLVMTy, adaptor.getOperand(), llvm::ArrayRef<std::int64_t>{0});
3002+
mlir::Value operand = adaptor.getOperand();
3003+
if (mlir::isa<cir::ComplexType>(op.getOperand().getType())) {
3004+
operand = mlir::LLVM::ExtractValueOp::create(
3005+
rewriter, op.getLoc(), resultLLVMTy, operand,
3006+
llvm::ArrayRef<std::int64_t>{0});
3007+
}
3008+
rewriter.replaceOp(op, operand);
30043009
return mlir::success();
30053010
}
30063011

clang/test/CIR/CodeGen/complex.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1140,7 +1140,8 @@ void real_on_scalar_glvalue() {
11401140
// CIR: %[[A_ADDR:.*]] = cir.alloca !cir.float, !cir.ptr<!cir.float>, ["a"]
11411141
// CIR: %[[B_ADDR:.*]] = cir.alloca !cir.float, !cir.ptr<!cir.float>, ["b", init]
11421142
// CIR: %[[TMP_A:.*]] = cir.load{{.*}} %[[A_ADDR]] : !cir.ptr<!cir.float>, !cir.float
1143-
// CIR: cir.store{{.*}} %[[TMP_A]], %[[B_ADDR]] : !cir.float, !cir.ptr<!cir.float>
1143+
// CIR: %[[A_REAL:.*]] = cir.complex.real %[[TMP_A]] : !cir.float -> !cir.float
1144+
// CIR: cir.store{{.*}} %[[A_REAL]], %[[B_ADDR]] : !cir.float, !cir.ptr<!cir.float>
11441145

11451146
// LLVM: %[[A_ADDR:.*]] = alloca float, i64 1, align 4
11461147
// LLVM: %[[B_ADDR:.*]] = alloca float, i64 1, align 4
@@ -1179,7 +1180,8 @@ void real_on_scalar_with_type_promotion() {
11791180
// CIR: %[[B_ADDR:.*]] = cir.alloca !cir.f16, !cir.ptr<!cir.f16>, ["b", init]
11801181
// CIR: %[[TMP_A:.*]] = cir.load{{.*}} %[[A_ADDR]] : !cir.ptr<!cir.f16>, !cir.f16
11811182
// CIR: %[[TMP_A_F32:.*]] = cir.cast floating %[[TMP_A]] : !cir.f16 -> !cir.float
1182-
// CIR: %[[TMP_A_F16:.*]] = cir.cast floating %[[TMP_A_F32]] : !cir.float -> !cir.f16
1183+
// CIR: %[[A_REAL:.*]] = cir.complex.real %[[TMP_A_F32]] : !cir.float -> !cir.float
1184+
// CIR: %[[TMP_A_F16:.*]] = cir.cast floating %[[A_REAL]] : !cir.float -> !cir.f16
11831185
// CIR: cir.store{{.*}} %[[TMP_A_F16]], %[[B_ADDR]] : !cir.f16, !cir.ptr<!cir.f16>
11841186

11851187
// LLVM: %[[A_ADDR:.*]] = alloca half, i64 1, align 2
@@ -1248,7 +1250,8 @@ void real_on_scalar_from_real_with_type_promotion() {
12481250
// CIR: %[[A_IMAG_F32:.*]] = cir.cast floating %[[A_IMAG]] : !cir.f16 -> !cir.float
12491251
// CIR: %[[A_COMPLEX_F32:.*]] = cir.complex.create %[[A_REAL_F32]], %[[A_IMAG_F32]] : !cir.float -> !cir.complex<!cir.float>
12501252
// CIR: %[[A_REAL_F32:.*]] = cir.complex.real %[[A_COMPLEX_F32]] : !cir.complex<!cir.float> -> !cir.float
1251-
// CIR: %[[A_REAL_F16:.*]] = cir.cast floating %[[A_REAL_F32]] : !cir.float -> !cir.f16
1253+
// CIR: %[[A_REAL:.*]] = cir.complex.real %[[A_REAL_F32]] : !cir.float -> !cir.float
1254+
// CIR: %[[A_REAL_F16:.*]] = cir.cast floating %[[A_REAL]] : !cir.float -> !cir.f16
12521255
// CIR: cir.store{{.*}} %[[A_REAL_F16]], %[[B_ADDR]] : !cir.f16, !cir.ptr<!cir.f16>
12531256

12541257
// LLVM: %[[A_ADDR:.*]] = alloca { half, half }, i64 1, align 2
@@ -1285,8 +1288,9 @@ void real_on_scalar_from_imag_with_type_promotion() {
12851288
// CIR: %[[A_IMAG_F32:.*]] = cir.cast floating %[[A_IMAG]] : !cir.f16 -> !cir.float
12861289
// CIR: %[[A_COMPLEX_F32:.*]] = cir.complex.create %[[A_REAL_F32]], %[[A_IMAG_F32]] : !cir.float -> !cir.complex<!cir.float>
12871290
// CIR: %[[A_IMAG_F32:.*]] = cir.complex.imag %[[A_COMPLEX_F32]] : !cir.complex<!cir.float> -> !cir.float
1288-
// CIR: %[[A_IMAG_F16:.*]] = cir.cast floating %[[A_IMAG_F32]] : !cir.float -> !cir.f16
1289-
// CIR: cir.store{{.*}} %[[A_IMAG_F16]], %[[B_ADDR]] : !cir.f16, !cir.ptr<!cir.f16>
1291+
// CIR: %[[A_REAL_F32:.*]] = cir.complex.real %[[A_IMAG_F32]] : !cir.float -> !cir.float
1292+
// CIR: %[[A_REAL_F16:.*]] = cir.cast floating %[[A_REAL_F32]] : !cir.float -> !cir.f16
1293+
// CIR: cir.store{{.*}} %[[A_REAL_F16]], %[[B_ADDR]] : !cir.f16, !cir.ptr<!cir.f16>
12901294

12911295
// LLVM: %[[A_ADDR:.*]] = alloca { half, half }, i64 1, align 2
12921296
// LLVM: %[[B_ADDR]] = alloca half, i64 1, align 2

0 commit comments

Comments
 (0)