Skip to content

Commit a368fb5

Browse files
authored
[CIR] Update ComplexImagOp to work on scalar type (#161571)
Update cir::ComplexImagOp to make it visible on scalars Issue #160568
1 parent 9a2a4f6 commit a368fb5

File tree

6 files changed

+52
-29
lines changed

6 files changed

+52
-29
lines changed

clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -178,9 +178,10 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
178178
}
179179

180180
mlir::Value createComplexImag(mlir::Location loc, mlir::Value operand) {
181-
auto operandTy = mlir::cast<cir::ComplexType>(operand.getType());
182-
return cir::ComplexImagOp::create(*this, loc, operandTy.getElementType(),
183-
operand);
181+
auto resultType = operand.getType();
182+
if (auto complexResultType = mlir::dyn_cast<cir::ComplexType>(resultType))
183+
resultType = complexResultType.getElementType();
184+
return cir::ComplexImagOp::create(*this, loc, resultType, operand);
184185
}
185186

186187
cir::LoadOp createLoad(mlir::Location loc, mlir::Value ptr,

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3308,18 +3308,20 @@ def CIR_ComplexRealOp : CIR_Op<"complex.real", [Pure]> {
33083308
def CIR_ComplexImagOp : CIR_Op<"complex.imag", [Pure]> {
33093309
let summary = "Extract the imaginary part of a complex value";
33103310
let description = [{
3311-
`cir.complex.imag` operation takes an operand of `!cir.complex` type and
3312-
yields the imaginary part of it.
3311+
`cir.complex.imag` operation takes an operand of `!cir.complex`, `!cir.int`
3312+
or `!cir.float`. If the operand is `!cir.complex`, the imag part of it will
3313+
be returned, otherwise a zero value will be returned.
33133314

33143315
Example:
33153316

33163317
```mlir
3317-
%1 = cir.complex.imag %0 : !cir.complex<!cir.float> -> !cir.float
3318+
%imag = cir.complex.imag %complex : !cir.complex<!cir.float> -> !cir.float
3319+
%imag = cir.complex.imag %scalar : !cir.float -> !cir.float
33183320
```
33193321
}];
33203322

33213323
let results = (outs CIR_AnyIntOrFloatType:$result);
3322-
let arguments = (ins CIR_ComplexType:$operand);
3324+
let arguments = (ins CIR_AnyComplexOrIntOrFloatType:$operand);
33233325

33243326
let assemblyFormat = [{
33253327
$operand `:` qualified(type($operand)) `->` qualified(type($result))

clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2159,16 +2159,16 @@ mlir::Value ScalarExprEmitter::VisitRealImag(const UnaryOperator *e,
21592159

21602160
// __imag on a scalar returns zero. Emit the subexpr to ensure side
21612161
// effects are evaluated, but not the actual value.
2162-
if (op->isGLValue())
2163-
cgf.emitLValue(op);
2164-
else if (!promotionTy.isNull())
2165-
cgf.emitPromotedScalarExpr(op, promotionTy);
2166-
else
2167-
cgf.emitScalarExpr(op);
2168-
2169-
mlir::Type valueTy =
2170-
cgf.convertType(promotionTy.isNull() ? e->getType() : promotionTy);
2171-
return builder.getNullValue(valueTy, loc);
2162+
mlir::Value operand;
2163+
if (op->isGLValue()) {
2164+
operand = cgf.emitLValue(op).getPointer();
2165+
operand = cir::LoadOp::create(builder, loc, operand);
2166+
} else if (!promotionTy.isNull()) {
2167+
operand = cgf.emitPromotedScalarExpr(op, promotionTy);
2168+
} else {
2169+
operand = cgf.emitScalarExpr(op);
2170+
}
2171+
return builder.createComplexImag(loc, operand);
21722172
}
21732173

21742174
/// Return the size or alignment of the type of argument of the sizeof

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2402,9 +2402,8 @@ OpFoldResult cir::ComplexCreateOp::fold(FoldAdaptor adaptor) {
24022402

24032403
LogicalResult cir::ComplexRealOp::verify() {
24042404
mlir::Type operandTy = getOperand().getType();
2405-
if (auto complexOperandTy = mlir::dyn_cast<cir::ComplexType>(operandTy)) {
2405+
if (auto complexOperandTy = mlir::dyn_cast<cir::ComplexType>(operandTy))
24062406
operandTy = complexOperandTy.getElementType();
2407-
}
24082407

24092408
if (getType() != operandTy) {
24102409
emitOpError() << ": result type does not match operand type";
@@ -2431,14 +2430,22 @@ OpFoldResult cir::ComplexRealOp::fold(FoldAdaptor adaptor) {
24312430
//===----------------------------------------------------------------------===//
24322431

24332432
LogicalResult cir::ComplexImagOp::verify() {
2434-
if (getType() != getOperand().getType().getElementType()) {
2433+
mlir::Type operandTy = getOperand().getType();
2434+
if (auto complexOperandTy = mlir::dyn_cast<cir::ComplexType>(operandTy))
2435+
operandTy = complexOperandTy.getElementType();
2436+
2437+
if (getType() != operandTy) {
24352438
emitOpError() << ": result type does not match operand type";
24362439
return failure();
24372440
}
2441+
24382442
return success();
24392443
}
24402444

24412445
OpFoldResult cir::ComplexImagOp::fold(FoldAdaptor adaptor) {
2446+
if (!mlir::isa<cir::ComplexType>(getOperand().getType()))
2447+
return nullptr;
2448+
24422449
if (auto complexCreateOp = getOperand().getDefiningOp<cir::ComplexCreateOp>())
24432450
return complexCreateOp.getOperand(1);
24442451

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3061,8 +3061,19 @@ mlir::LogicalResult CIRToLLVMComplexImagOpLowering::matchAndRewrite(
30613061
cir::ComplexImagOp op, OpAdaptor adaptor,
30623062
mlir::ConversionPatternRewriter &rewriter) const {
30633063
mlir::Type resultLLVMTy = getTypeConverter()->convertType(op.getType());
3064-
rewriter.replaceOpWithNewOp<mlir::LLVM::ExtractValueOp>(
3065-
op, resultLLVMTy, adaptor.getOperand(), llvm::ArrayRef<std::int64_t>{1});
3064+
mlir::Value operand = adaptor.getOperand();
3065+
mlir::Location loc = op.getLoc();
3066+
3067+
if (mlir::isa<cir::ComplexType>(op.getOperand().getType())) {
3068+
operand = mlir::LLVM::ExtractValueOp::create(
3069+
rewriter, loc, resultLLVMTy, operand, llvm::ArrayRef<std::int64_t>{1});
3070+
} else {
3071+
mlir::TypedAttr zeroAttr = rewriter.getZeroAttr(resultLLVMTy);
3072+
operand =
3073+
mlir::LLVM::ConstantOp::create(rewriter, loc, resultLLVMTy, zeroAttr);
3074+
}
3075+
3076+
rewriter.replaceOp(op, operand);
30663077
return mlir::success();
30673078
}
30683079

clang/test/CIR/CodeGen/complex.cpp

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1160,8 +1160,9 @@ void imag_on_scalar_glvalue() {
11601160

11611161
// CIR: %[[A_ADDR:.*]] = cir.alloca !cir.float, !cir.ptr<!cir.float>, ["a"]
11621162
// CIR: %[[B_ADDR:.*]] = cir.alloca !cir.float, !cir.ptr<!cir.float>, ["b", init]
1163-
// CIR: %[[CONST_ZERO:.*]] = cir.const #cir.fp<0.000000e+00> : !cir.float
1164-
// CIR: cir.store{{.*}} %[[CONST_ZERO]], %[[B_ADDR]] : !cir.float, !cir.ptr<!cir.float>
1163+
// CIR: %[[TMP_A:.*]] = cir.load %[[A_ADDR]] : !cir.ptr<!cir.float>, !cir.float
1164+
// CIR: %[[A_IMAG:.*]] = cir.complex.imag %[[TMP_A]] : !cir.float -> !cir.float
1165+
// CIR: cir.store{{.*}} %[[A_IMAG]], %[[B_ADDR]] : !cir.float, !cir.ptr<!cir.float>
11651166

11661167
// LLVM: %[[A_ADDR:.*]] = alloca float, i64 1, align 4
11671168
// LLVM: %[[B_ADDR:.*]] = alloca float, i64 1, align 4
@@ -1205,9 +1206,10 @@ void imag_on_scalar_with_type_promotion() {
12051206

12061207
// CIR: %[[A_ADDR:.*]] = cir.alloca !cir.f16, !cir.ptr<!cir.f16>, ["a"]
12071208
// CIR: %[[B_ADDR:.*]] = cir.alloca !cir.f16, !cir.ptr<!cir.f16>, ["b", init]
1208-
// CIR: %[[CONST_ZERO:.*]] = cir.const #cir.fp<0.000000e+00> : !cir.float
1209-
// CIR: %[[CONST_ZERO_F16:.*]] = cir.cast floating %[[CONST_ZERO]] : !cir.float -> !cir.f16
1210-
// CIR: cir.store{{.*}} %[[CONST_ZERO_F16]], %[[B_ADDR]] : !cir.f16, !cir.ptr<!cir.f16>
1209+
// CIR: %[[TMP_A:.*]] = cir.load %[[A_ADDR]] : !cir.ptr<!cir.f16>, !cir.f16
1210+
// CIR: %[[A_IMAG:.*]] = cir.complex.imag %[[TMP_A]] : !cir.f16 -> !cir.f16
1211+
// CIR: %[[A_IMAG_F16:.*]] = cir.cast floating %[[A_IMAG]] : !cir.f16 -> !cir.f16
1212+
// CIR: cir.store{{.*}} %[[A_IMAG_F16]], %[[B_ADDR]] : !cir.f16, !cir.ptr<!cir.f16>
12111213

12121214
// LLVM: %[[A_ADDR:.*]] = alloca half, i64 1, align 2
12131215
// LLVM: %[[B_ADDR:.*]] = alloca half, i64 1, align 2
@@ -1225,8 +1227,8 @@ void imag_on_const_scalar() {
12251227
// CIR: %[[A_ADDR:.*]] = cir.alloca !cir.float, !cir.ptr<!cir.float>, ["a"]
12261228
// CIR: %[[B_ADDR:.*]] = cir.alloca !cir.float, !cir.ptr<!cir.float>, ["b", init]
12271229
// CIR: %[[CONST_ONE:.*]] = cir.const #cir.fp<1.000000e+00> : !cir.float
1228-
// CIR: %[[CONST_ZERO:.*]] = cir.const #cir.fp<0.000000e+00> : !cir.float
1229-
// CIR: cir.store{{.*}} %[[CONST_ZERO]], %[[B_ADDR]] : !cir.float, !cir.ptr<!cir.float>
1230+
// CIR: %[[CONST_IMAG:.*]] = cir.complex.imag %[[CONST_ONE]] : !cir.float -> !cir.float
1231+
// CIR: cir.store{{.*}} %[[CONST_IMAG]], %[[B_ADDR]] : !cir.float, !cir.ptr<!cir.float>
12301232

12311233
// LLVM: %[[A_ADDR:.*]] = alloca float, i64 1, align 4
12321234
// LLVM: %[[B_ADDR:.*]] = alloca float, i64 1, align 4

0 commit comments

Comments
 (0)