-
Notifications
You must be signed in to change notification settings - Fork 15.1k
[CIR] Update ComplexRealOp to work on scalar type #161080
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[CIR] Update ComplexRealOp to work on scalar type #161080
Conversation
|
@llvm/pr-subscribers-clang @llvm/pr-subscribers-clangir Author: Amr Hesham (AmrDeveloper) ChangesUpdate cir::CreateRealOp to make it visible on scalars Issue #160568 Full diff: https://github.com/llvm/llvm-project/pull/161080.diff 7 Files Affected:
diff --git a/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h b/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h
index a3f167e3cde2c..7dbb6e42dbb65 100644
--- a/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h
+++ b/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h
@@ -148,9 +148,10 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
}
mlir::Value createComplexReal(mlir::Location loc, mlir::Value operand) {
- auto operandTy = mlir::cast<cir::ComplexType>(operand.getType());
- return cir::ComplexRealOp::create(*this, loc, operandTy.getElementType(),
- operand);
+ auto resultType = operand.getType();
+ if (mlir::isa<cir::ComplexType>(resultType))
+ resultType = mlir::cast<cir::ComplexType>(resultType).getElementType();
+ return cir::ComplexRealOp::create(*this, loc, resultType, operand);
}
mlir::Value createComplexImag(mlir::Location loc, mlir::Value operand) {
diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index bb394440bf8d8..eaf5b8958c7b8 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -3245,18 +3245,19 @@ def CIR_ComplexCreateOp : CIR_Op<"complex.create", [Pure, SameTypeOperands]> {
def CIR_ComplexRealOp : CIR_Op<"complex.real", [Pure]> {
let summary = "Extract the real part of a complex value";
let description = [{
- `cir.complex.real` operation takes an operand of `!cir.complex` type and
- yields the real part of it.
+ `cir.complex.real` operation takes an operand of `!cir.complex` or scalar
+ type and yields the real part of it.
Example:
```mlir
- %1 = cir.complex.real %0 : !cir.complex<!cir.float> -> !cir.float
+ %real = cir.complex.real %complex : !cir.complex<!cir.float> -> !cir.float
+ %real = cir.complex.real %scalar : !cir.float -> !cir.float
```
}];
let results = (outs CIR_AnyIntOrFloatType:$result);
- let arguments = (ins CIR_ComplexType:$operand);
+ let arguments = (ins CIR_AnyComplexOrIntOrFloatType:$operand);
let assemblyFormat = [{
$operand `:` qualified(type($operand)) `->` qualified(type($result))
diff --git a/clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td b/clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td
index 82f6e1d33043e..da03a291a7690 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td
@@ -165,6 +165,12 @@ def CIR_AnyIntOrFloatType : AnyTypeOf<[CIR_AnyFloatType, CIR_AnyIntType],
def CIR_AnyComplexType : CIR_TypeBase<"::cir::ComplexType", "complex type">;
+def CIR_AnyComplexOrIntOrFloatType : AnyTypeOf<[
+ CIR_AnyComplexType, CIR_AnyFloatType, CIR_AnyIntType
+], "complex, integer or floating point type"> {
+ let cppFunctionName = "isComplexOrIntegerOrFloatingPointType";
+}
+
//===----------------------------------------------------------------------===//
// Array Type predicates
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
index bd09d78cd0eb6..f8dcae042740a 100644
--- a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
@@ -2146,8 +2146,10 @@ mlir::Value ScalarExprEmitter::VisitRealImag(const UnaryOperator *e,
}
if (e->getOpcode() == UO_Real) {
- return promotionTy.isNull() ? Visit(op)
- : cgf.emitPromotedScalarExpr(op, promotionTy);
+ mlir::Value operand = promotionTy.isNull()
+ ? Visit(op)
+ : cgf.emitPromotedScalarExpr(op, promotionTy);
+ return builder.createComplexReal(loc, operand);
}
// __imag on a scalar returns zero. Emit the subexpr to ensure side
diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index 58ef500446aa7..c3dda37db37cf 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -2302,14 +2302,23 @@ OpFoldResult cir::ComplexCreateOp::fold(FoldAdaptor adaptor) {
//===----------------------------------------------------------------------===//
LogicalResult cir::ComplexRealOp::verify() {
- if (getType() != getOperand().getType().getElementType()) {
+ mlir::Type operandTy = getOperand().getType();
+ if (mlir::isa<cir::ComplexType>(operandTy)) {
+ operandTy = mlir::cast<cir::ComplexType>(operandTy).getElementType();
+ }
+
+ if (getType() != operandTy) {
emitOpError() << ": result type does not match operand type";
return failure();
}
+
return success();
}
OpFoldResult cir::ComplexRealOp::fold(FoldAdaptor adaptor) {
+ if (!mlir::isa<cir::ComplexType>(getOperand().getType()))
+ return nullptr;
+
if (auto complexCreateOp = getOperand().getDefiningOp<cir::ComplexCreateOp>())
return complexCreateOp.getOperand(0);
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
index 876948d53010b..664bf55d2b7f5 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
@@ -2992,8 +2992,13 @@ mlir::LogicalResult CIRToLLVMComplexRealOpLowering::matchAndRewrite(
cir::ComplexRealOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
mlir::Type resultLLVMTy = getTypeConverter()->convertType(op.getType());
- rewriter.replaceOpWithNewOp<mlir::LLVM::ExtractValueOp>(
- op, resultLLVMTy, adaptor.getOperand(), llvm::ArrayRef<std::int64_t>{0});
+ mlir::Value operand = adaptor.getOperand();
+ if (mlir::isa<cir::ComplexType>(op.getOperand().getType())) {
+ operand = mlir::LLVM::ExtractValueOp::create(
+ rewriter, op.getLoc(), resultLLVMTy, operand,
+ llvm::ArrayRef<std::int64_t>{0});
+ }
+ rewriter.replaceOp(op, operand);
return mlir::success();
}
diff --git a/clang/test/CIR/CodeGen/complex.cpp b/clang/test/CIR/CodeGen/complex.cpp
index e90163172d2df..abb3431e5f37d 100644
--- a/clang/test/CIR/CodeGen/complex.cpp
+++ b/clang/test/CIR/CodeGen/complex.cpp
@@ -1140,7 +1140,8 @@ void real_on_scalar_glvalue() {
// CIR: %[[A_ADDR:.*]] = cir.alloca !cir.float, !cir.ptr<!cir.float>, ["a"]
// CIR: %[[B_ADDR:.*]] = cir.alloca !cir.float, !cir.ptr<!cir.float>, ["b", init]
// CIR: %[[TMP_A:.*]] = cir.load{{.*}} %[[A_ADDR]] : !cir.ptr<!cir.float>, !cir.float
-// CIR: cir.store{{.*}} %[[TMP_A]], %[[B_ADDR]] : !cir.float, !cir.ptr<!cir.float>
+// CIR: %[[A_REAL:.*]] = cir.complex.real %2 : !cir.float -> !cir.float
+// CIR: cir.store{{.*}} %[[A_REAL]], %[[B_ADDR]] : !cir.float, !cir.ptr<!cir.float>
// LLVM: %[[A_ADDR:.*]] = alloca float, i64 1, align 4
// LLVM: %[[B_ADDR:.*]] = alloca float, i64 1, align 4
@@ -1179,7 +1180,8 @@ void real_on_scalar_with_type_promotion() {
// CIR: %[[B_ADDR:.*]] = cir.alloca !cir.f16, !cir.ptr<!cir.f16>, ["b", init]
// CIR: %[[TMP_A:.*]] = cir.load{{.*}} %[[A_ADDR]] : !cir.ptr<!cir.f16>, !cir.f16
// CIR: %[[TMP_A_F32:.*]] = cir.cast(floating, %[[TMP_A]] : !cir.f16), !cir.float
-// CIR: %[[TMP_A_F16:.*]] = cir.cast(floating, %[[TMP_A_F32]] : !cir.float), !cir.f16
+// CIR: %[[A_REAL:.*]] = cir.complex.real %[[TMP_A_F32]] : !cir.float -> !cir.float
+// CIR: %[[TMP_A_F16:.*]] = cir.cast(floating, %[[A_REAL]] : !cir.float), !cir.f16
// CIR: cir.store{{.*}} %[[TMP_A_F16]], %[[B_ADDR]] : !cir.f16, !cir.ptr<!cir.f16>
// LLVM: %[[A_ADDR:.*]] = alloca half, i64 1, align 2
@@ -1248,7 +1250,8 @@ void real_on_scalar_from_real_with_type_promotion() {
// CIR: %[[A_IMAG_F32:.*]] = cir.cast(floating, %[[A_IMAG]] : !cir.f16), !cir.float
// CIR: %[[A_COMPLEX_F32:.*]] = cir.complex.create %[[A_REAL_F32]], %[[A_IMAG_F32]] : !cir.float -> !cir.complex<!cir.float>
// CIR: %[[A_REAL_F32:.*]] = cir.complex.real %[[A_COMPLEX_F32]] : !cir.complex<!cir.float> -> !cir.float
-// CIR: %[[A_REAL_F16:.*]] = cir.cast(floating, %[[A_REAL_F32]] : !cir.float), !cir.f16
+// CIR: %[[A_REAL:.*]] = cir.complex.real %[[A_REAL_F32]] : !cir.float -> !cir.float
+// CIR: %[[A_REAL_F16:.*]] = cir.cast(floating, %[[A_REAL]] : !cir.float), !cir.f16
// CIR: cir.store{{.*}} %[[A_REAL_F16]], %[[B_ADDR]] : !cir.f16, !cir.ptr<!cir.f16>
// LLVM: %[[A_ADDR:.*]] = alloca { half, half }, i64 1, align 2
|
andykaylor
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good except for a couple of nits
| `cir.complex.real` operation takes an operand of `!cir.complex` type and | ||
| yields the real part of it. | ||
| `cir.complex.real` operation takes an operand of `!cir.complex`, `!cir.int` | ||
| or `!cir.float` type and yields the real part of it. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This wording is awkward. What's the "real" part of a float or int? Please reword this to say that if the operand is complex, the operation returns the real component of the complex value, otherwise the value is returned unmodfied.
57ce405 to
10522fd
Compare
Update cir::CreateRealOp to make it visible on scalars Issue llvm#160568
Update cir::CreateRealOp to make it visible on scalars
Issue #160568