-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[CIR] Upstream MulOp for ComplexType #150834
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@llvm/pr-subscribers-clang Author: Amr Hesham (AmrDeveloper) ChangesThis change adds support for mul op for ComplexType Patch is 39.94 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/150834.diff 5 Files Affected:
diff --git a/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h b/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h
index 5c04d59475b6a..aea8315c274a5 100644
--- a/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h
+++ b/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h
@@ -441,6 +441,10 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
return create<cir::CmpOp>(loc, getBoolTy(), kind, lhs, rhs);
}
+ mlir::Value createIsNaN(mlir::Location loc, mlir::Value operand) {
+ return createCompare(loc, cir::CmpOpKind::ne, operand, operand);
+ }
+
mlir::Value createShift(mlir::Location loc, mlir::Value lhs, mlir::Value rhs,
bool isShiftLeft) {
return create<cir::ShiftOp>(loc, lhs.getType(), lhs, rhs, isShiftLeft);
diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index e2ddbd12c77bd..eb724dd299d3b 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -2804,6 +2804,47 @@ def CIR_ComplexSubOp : CIR_Op<"complex.sub", [
}];
}
+//===----------------------------------------------------------------------===//
+// ComplexMulOp
+//===----------------------------------------------------------------------===//
+
+def CIR_ComplexRangeKind : CIR_I32EnumAttr<
+ "ComplexRangeKind", "complex multiplication and division implementation", [
+ I32EnumAttrCase<"Full", 0, "full">,
+ I32EnumAttrCase<"Improved", 1, "improved">,
+ I32EnumAttrCase<"Promoted", 2, "promoted">,
+ I32EnumAttrCase<"Basic", 3, "basic">,
+ I32EnumAttrCase<"None", 4, "none">
+]>;
+
+def CIR_ComplexMulOp : CIR_Op<"complex.mul", [
+ Pure, SameOperandsAndResultType
+]> {
+ let summary = "Complex multiplication";
+ let description = [{
+ The `cir.complex.mul` operation takes two complex numbers and returns
+ their product.
+
+ Example:
+
+ ```mlir
+ %2 = cir.complex.mul %0, %1 range(basic) : !cir.complex<!cir.float>
+ %2 = cir.complex.mul %0, %1 range(full) : !cir.complex<!cir.float>
+ ```
+ }];
+
+ let arguments = (ins
+ CIR_ComplexType:$lhs,
+ CIR_ComplexType:$rhs,
+ CIR_ComplexRangeKind:$range
+ );
+
+ let results = (outs CIR_ComplexType:$result);
+
+ let assemblyFormat = [{
+ $lhs `,` $rhs `range` `(` $range `)` `:` qualified(type($result)) attr-dict
+ }];
+}
//===----------------------------------------------------------------------===//
// Bit Manipulation Operations
diff --git a/clang/lib/CIR/CodeGen/CIRGenExprComplex.cpp b/clang/lib/CIR/CodeGen/CIRGenExprComplex.cpp
index 02685a3d64121..196df0e9e6405 100644
--- a/clang/lib/CIR/CodeGen/CIRGenExprComplex.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenExprComplex.cpp
@@ -110,6 +110,7 @@ class ComplexExprEmitter : public StmtVisitor<ComplexExprEmitter, mlir::Value> {
mlir::Value emitBinAdd(const BinOpInfo &op);
mlir::Value emitBinSub(const BinOpInfo &op);
+ mlir::Value emitBinMul(const BinOpInfo &op);
QualType getPromotionType(QualType ty, bool isDivOpCode = false) {
if (auto *complexTy = ty->getAs<ComplexType>()) {
@@ -142,6 +143,7 @@ class ComplexExprEmitter : public StmtVisitor<ComplexExprEmitter, mlir::Value> {
HANDLEBINOP(Add)
HANDLEBINOP(Sub)
+ HANDLEBINOP(Mul)
#undef HANDLEBINOP
};
} // namespace
@@ -531,6 +533,7 @@ mlir::Value ComplexExprEmitter::emitPromoted(const Expr *e,
return emitBin##OP(emitBinOps(bo, promotionTy));
HANDLE_BINOP(Add)
HANDLE_BINOP(Sub)
+ HANDLE_BINOP(Mul)
#undef HANDLE_BINOP
default:
break;
@@ -582,6 +585,30 @@ mlir::Value ComplexExprEmitter::emitBinSub(const BinOpInfo &op) {
return builder.create<cir::ComplexSubOp>(op.loc, op.lhs, op.rhs);
}
+static cir::ComplexRangeKind
+getComplexRangeAttr(LangOptions::ComplexRangeKind range) {
+ switch (range) {
+ case LangOptions::CX_Full:
+ return cir::ComplexRangeKind::Full;
+ case LangOptions::CX_Improved:
+ return cir::ComplexRangeKind::Improved;
+ case LangOptions::CX_Promoted:
+ return cir::ComplexRangeKind::Promoted;
+ case LangOptions::CX_Basic:
+ return cir::ComplexRangeKind::Basic;
+ case LangOptions::CX_None:
+ return cir::ComplexRangeKind::None;
+ }
+}
+
+mlir::Value ComplexExprEmitter::emitBinMul(const BinOpInfo &op) {
+ assert(!cir::MissingFeatures::fastMathFlags());
+ assert(!cir::MissingFeatures::cgFPOptionsRAII());
+ cir::ComplexRangeKind rangeKind =
+ getComplexRangeAttr(op.fpFeatures.getComplexRange());
+ return builder.create<cir::ComplexMulOp>(op.loc, op.lhs, op.rhs, rangeKind);
+}
+
LValue CIRGenFunction::emitComplexAssignmentLValue(const BinaryOperator *e) {
assert(e->getOpcode() == BO_Assign && "Expected assign op");
diff --git a/clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp b/clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp
index cef83eae2ef50..a2ce550819183 100644
--- a/clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp
+++ b/clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp
@@ -15,7 +15,6 @@
#include "clang/CIR/Dialect/Passes.h"
#include "clang/CIR/MissingFeatures.h"
-#include <iostream>
#include <memory>
using namespace mlir;
@@ -28,20 +27,46 @@ struct LoweringPreparePass : public LoweringPrepareBase<LoweringPreparePass> {
void runOnOp(mlir::Operation *op);
void lowerCastOp(cir::CastOp op);
+ void lowerComplexMulOp(cir::ComplexMulOp op);
void lowerUnaryOp(cir::UnaryOp op);
void lowerArrayCtor(ArrayCtor op);
+ cir::FuncOp buildRuntimeFunction(
+ mlir::OpBuilder &builder, llvm::StringRef name, mlir::Location loc,
+ cir::FuncType type,
+ cir::GlobalLinkageKind linkage = cir::GlobalLinkageKind::ExternalLinkage);
+
///
/// AST related
/// -----------
clang::ASTContext *astCtx;
+ /// Tracks current module.
+ mlir::ModuleOp theModule;
+
void setASTContext(clang::ASTContext *c) { astCtx = c; }
};
} // namespace
+cir::FuncOp LoweringPreparePass::buildRuntimeFunction(
+ mlir::OpBuilder &builder, llvm::StringRef name, mlir::Location loc,
+ cir::FuncType type, cir::GlobalLinkageKind linkage) {
+ cir::FuncOp f = dyn_cast_or_null<FuncOp>(SymbolTable::lookupNearestSymbolFrom(
+ theModule, StringAttr::get(theModule->getContext(), name)));
+ if (!f) {
+ f = builder.create<cir::FuncOp>(loc, name, type);
+ f.setLinkageAttr(
+ cir::GlobalLinkageKindAttr::get(builder.getContext(), linkage));
+ mlir::SymbolTable::setSymbolVisibility(
+ f, mlir::SymbolTable::Visibility::Private);
+
+ assert(!cir::MissingFeatures::opFuncExtraAttrs());
+ }
+ return f;
+}
+
static mlir::Value lowerScalarToComplexCast(mlir::MLIRContext &ctx,
cir::CastOp op) {
cir::CIRBaseBuilderTy builder(ctx);
@@ -127,6 +152,120 @@ void LoweringPreparePass::lowerCastOp(cir::CastOp op) {
}
}
+static mlir::Value buildComplexBinOpLibCall(
+ LoweringPreparePass &pass, CIRBaseBuilderTy &builder,
+ llvm::StringRef (*libFuncNameGetter)(llvm::APFloat::Semantics),
+ mlir::Location loc, cir::ComplexType ty, mlir::Value lhsReal,
+ mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag) {
+ cir::FPTypeInterface elementTy =
+ mlir::cast<cir::FPTypeInterface>(ty.getElementType());
+
+ llvm::StringRef libFuncName = libFuncNameGetter(
+ llvm::APFloat::SemanticsToEnum(elementTy.getFloatSemantics()));
+ llvm::SmallVector<mlir::Type, 4> libFuncInputTypes(4, elementTy);
+
+ cir::FuncType libFuncTy = cir::FuncType::get(libFuncInputTypes, ty);
+
+ cir::FuncOp libFunc;
+ {
+ mlir::OpBuilder::InsertionGuard ipGuard{builder};
+ builder.setInsertionPointToStart(pass.theModule.getBody());
+ libFunc = pass.buildRuntimeFunction(builder, libFuncName, loc, libFuncTy);
+ }
+
+ cir::CallOp call =
+ builder.createCallOp(loc, libFunc, {lhsReal, lhsImag, rhsReal, rhsImag});
+ return call.getResult();
+}
+
+static llvm::StringRef
+getComplexMulLibCallName(llvm::APFloat::Semantics semantics) {
+ switch (semantics) {
+ case llvm::APFloat::S_IEEEhalf:
+ return "__mulhc3";
+ case llvm::APFloat::S_IEEEsingle:
+ return "__mulsc3";
+ case llvm::APFloat::S_IEEEdouble:
+ return "__muldc3";
+ case llvm::APFloat::S_PPCDoubleDouble:
+ return "__multc3";
+ case llvm::APFloat::S_x87DoubleExtended:
+ return "__mulxc3";
+ case llvm::APFloat::S_IEEEquad:
+ return "__multc3";
+ default:
+ llvm_unreachable("unsupported floating point type");
+ }
+}
+
+static mlir::Value lowerComplexMul(LoweringPreparePass &pass,
+ CIRBaseBuilderTy &builder,
+ mlir::Location loc, cir::ComplexMulOp op,
+ mlir::Value lhsReal, mlir::Value lhsImag,
+ mlir::Value rhsReal, mlir::Value rhsImag) {
+ // (a+bi) * (c+di) = (ac-bd) + (ad+bc)i
+ mlir::Value resultRealLhs =
+ builder.createBinop(loc, lhsReal, cir::BinOpKind::Mul, rhsReal);
+ mlir::Value resultRealRhs =
+ builder.createBinop(loc, lhsImag, cir::BinOpKind::Mul, rhsImag);
+ mlir::Value resultImagLhs =
+ builder.createBinop(loc, lhsReal, cir::BinOpKind::Mul, rhsImag);
+ mlir::Value resultImagRhs =
+ builder.createBinop(loc, lhsImag, cir::BinOpKind::Mul, rhsReal);
+ mlir::Value resultReal = builder.createBinop(
+ loc, resultRealLhs, cir::BinOpKind::Sub, resultRealRhs);
+ mlir::Value resultImag = builder.createBinop(
+ loc, resultImagLhs, cir::BinOpKind::Add, resultImagRhs);
+ mlir::Value algebraicResult =
+ builder.createComplexCreate(loc, resultReal, resultImag);
+
+ cir::ComplexType complexTy = op.getType();
+ cir::ComplexRangeKind rangeKind = op.getRange();
+ if (mlir::isa<cir::IntType>(complexTy.getElementType()) ||
+ rangeKind == cir::ComplexRangeKind::Basic ||
+ rangeKind == cir::ComplexRangeKind::Improved ||
+ rangeKind == cir::ComplexRangeKind::Promoted)
+ return algebraicResult;
+
+ // Check whether the real part and the imaginary part of the result are both
+ // NaN. If so, emit a library call to compute the multiplication instead.
+ // We check a value against NaN by comparing the value against itself.
+ mlir::Value resultRealIsNaN = builder.createIsNaN(loc, resultReal);
+ mlir::Value resultImagIsNaN = builder.createIsNaN(loc, resultImag);
+ mlir::Value resultRealAndImagAreNaN =
+ builder.createLogicalAnd(loc, resultRealIsNaN, resultImagIsNaN);
+
+ return builder
+ .create<cir::TernaryOp>(
+ loc, resultRealAndImagAreNaN,
+ [&](mlir::OpBuilder &, mlir::Location) {
+ mlir::Value libCallResult = buildComplexBinOpLibCall(
+ pass, builder, &getComplexMulLibCallName, loc, complexTy,
+ lhsReal, lhsImag, rhsReal, rhsImag);
+ builder.createYield(loc, libCallResult);
+ },
+ [&](mlir::OpBuilder &, mlir::Location) {
+ builder.createYield(loc, algebraicResult);
+ })
+ .getResult();
+}
+
+void LoweringPreparePass::lowerComplexMulOp(cir::ComplexMulOp op) {
+ cir::CIRBaseBuilderTy builder(getContext());
+ builder.setInsertionPointAfter(op);
+ mlir::Location loc = op.getLoc();
+ mlir::TypedValue<cir::ComplexType> lhs = op.getLhs();
+ mlir::TypedValue<cir::ComplexType> rhs = op.getRhs();
+ mlir::Value lhsReal = builder.createComplexReal(loc, lhs);
+ mlir::Value lhsImag = builder.createComplexImag(loc, lhs);
+ mlir::Value rhsReal = builder.createComplexReal(loc, rhs);
+ mlir::Value rhsImag = builder.createComplexImag(loc, rhs);
+ mlir::Value loweredResult = lowerComplexMul(*this, builder, loc, op, lhsReal,
+ lhsImag, rhsReal, rhsImag);
+ op.replaceAllUsesWith(loweredResult);
+ op.erase();
+}
+
void LoweringPreparePass::lowerUnaryOp(cir::UnaryOp op) {
mlir::Type ty = op.getType();
if (!mlir::isa<cir::ComplexType>(ty))
@@ -247,17 +386,23 @@ void LoweringPreparePass::runOnOp(mlir::Operation *op) {
lowerArrayCtor(arrayCtor);
else if (auto cast = mlir::dyn_cast<cir::CastOp>(op))
lowerCastOp(cast);
+ else if (auto complexMul = mlir::dyn_cast<cir::ComplexMulOp>(op))
+ lowerComplexMulOp(complexMul);
else if (auto unary = mlir::dyn_cast<cir::UnaryOp>(op))
lowerUnaryOp(unary);
}
void LoweringPreparePass::runOnOperation() {
mlir::Operation *op = getOperation();
+ if (isa<::mlir::ModuleOp>(op)) {
+ theModule = cast<::mlir::ModuleOp>(op);
+ }
llvm::SmallVector<mlir::Operation *> opsToTransform;
op->walk([&](mlir::Operation *op) {
- if (mlir::isa<cir::ArrayCtor, cir::CastOp, cir::UnaryOp>(op))
+ if (mlir::isa<cir::ArrayCtor, cir::CastOp, cir::ComplexMulOp, cir::UnaryOp>(
+ op))
opsToTransform.push_back(op);
});
diff --git a/clang/test/CIR/CodeGen/complex-mul-div.cpp b/clang/test/CIR/CodeGen/complex-mul-div.cpp
new file mode 100644
index 0000000000000..5fe682ac28b53
--- /dev/null
+++ b/clang/test/CIR/CodeGen/complex-mul-div.cpp
@@ -0,0 +1,325 @@
+// complex-range basic
+// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -complex-range=basic -Wno-unused-value -fclangir -emit-cir -mmlir --mlir-print-ir-before=cir-canonicalize -o %t.cir %s 2>&1 | FileCheck --check-prefix=CIR-BEFORE-BASIC %s
+// RUN: %clang_cc1 -std=c++20 -triple x86_64-unknown-linux-gnu -complex-range=basic -Wno-unused-value -fclangir -emit-cir %s -o %t.cir
+// RUN: FileCheck --input-file=%t.cir %s -check-prefix=CIR-AFTER-BASIC
+// RUN: %clang_cc1 -std=c++20 -triple x86_64-unknown-linux-gnu -complex-range=basic -Wno-unused-value -fclangir -emit-llvm %s -o %t-cir.ll
+// RUN: FileCheck --input-file=%t-cir.ll %s -check-prefix=LLVM-BASIC
+// RUN: %clang_cc1 -std=c++20 -triple x86_64-unknown-linux-gnu -complex-range=basic -Wno-unused-value -emit-llvm %s -o %t.ll
+// RUN: FileCheck --input-file=%t.ll %s -check-prefix=OGCG-BASIC
+
+// complex-range improved
+// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -complex-range=improved -Wno-unused-value -fclangir -emit-cir -mmlir --mlir-print-ir-before=cir-canonicalize -o %t.cir %s 2>&1 | FileCheck --check-prefix=CIR-BEFORE-IMPROVED %s
+// RUN: %clang_cc1 -std=c++20 -triple x86_64-unknown-linux-gnu -complex-range=improved -Wno-unused-value -fclangir -emit-cir %s -o %t.cir
+// RUN: FileCheck --input-file=%t.cir %s -check-prefix=CIR-AFTER-IMPROVED
+// RUN: %clang_cc1 -std=c++20 -triple x86_64-unknown-linux-gnu -complex-range=improved -Wno-unused-value -fclangir -emit-llvm %s -o %t-cir.ll
+// RUN: FileCheck --input-file=%t-cir.ll %s -check-prefix=LLVM-IMPROVED
+// RUN: %clang_cc1 -std=c++20 -triple x86_64-unknown-linux-gnu -complex-range=improved -Wno-unused-value -emit-llvm %s -o %t.ll
+// RUN: FileCheck --input-file=%t.ll %s -check-prefix=OGCG-IMPROVED
+
+// complex-range promoted
+// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -complex-range=promoted -Wno-unused-value -fclangir -emit-cir -mmlir --mlir-print-ir-before=cir-canonicalize -o %t.cir %s 2>&1 | FileCheck --check-prefix=CIR-BEFORE-PROMOTED %s
+// RUN: %clang_cc1 -std=c++20 -triple x86_64-unknown-linux-gnu -complex-range=promoted -Wno-unused-value -fclangir -emit-cir %s -o %t.cir
+// RUN: FileCheck --input-file=%t.cir %s -check-prefix=CIR-AFTER-PROMOTED
+// RUN: %clang_cc1 -std=c++20 -triple x86_64-unknown-linux-gnu -complex-range=promoted -Wno-unused-value -fclangir -emit-llvm %s -o %t-cir.ll
+// RUN: FileCheck --input-file=%t-cir.ll %s -check-prefix=LLVM-PROMOTED
+// RUN: %clang_cc1 -std=c++20 -triple x86_64-unknown-linux-gnu -complex-range=promoted -Wno-unused-value -emit-llvm %s -o %t.ll
+// RUN: FileCheck --input-file=%t.ll %s -check-prefix=OGCG-PROMOTED
+
+// complex-range full
+// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -complex-range=full -Wno-unused-value -fclangir -emit-cir -mmlir --mlir-print-ir-before=cir-canonicalize -o %t.cir %s 2>&1 | FileCheck --check-prefix=CIR-BEFORE-FULL %s
+// RUN: %clang_cc1 -std=c++20 -triple x86_64-unknown-linux-gnu -complex-range=full -Wno-unused-value -fclangir -emit-cir %s -o %t.cir
+// RUN: FileCheck --input-file=%t.cir %s -check-prefix=CIR-AFTER-FULL
+// RUN: %clang_cc1 -std=c++20 -triple x86_64-unknown-linux-gnu -complex-range=full -Wno-unused-value -fclangir -emit-llvm %s -o %t-cir.ll
+// RUN: FileCheck --input-file=%t-cir.ll %s -check-prefix=LLVM-FULL
+// RUN: %clang_cc1 -std=c++20 -triple x86_64-unknown-linux-gnu -complex-range=full -Wno-unused-value -emit-llvm %s -o %t.ll
+// RUN: FileCheck --input-file=%t.ll %s -check-prefix=OGCG-FULL
+
+void foo() {
+ float _Complex a;
+ float _Complex b;
+ float _Complex c = a * b;
+}
+
+// CIR-BEFORE-BASIC: %{{.*}} = cir.complex.mul {{.*}}, {{.*}} range(basic) : !cir.complex<!cir.float>
+
+// CIR-AFTER-BASIC: %[[A_ADDR:.*]] = cir.alloca !cir.complex<!cir.float>, !cir.ptr<!cir.complex<!cir.float>>, ["a"]
+// CIR-AFTER-BASIC: %[[B_ADDR:.*]] = cir.alloca !cir.complex<!cir.float>, !cir.ptr<!cir.complex<!cir.float>>, ["b"]
+// CIR-AFTER-BASIC: %[[C_ADDR:.*]] = cir.alloca !cir.complex<!cir.float>, !cir.ptr<!cir.complex<!cir.float>>, ["c", init]
+// CIR-AFTER-BASIC: %[[TMP_A:.*]] = cir.load{{.*}} %[[A_ADDR]] : !cir.ptr<!cir.complex<!cir.float>>, !cir.complex<!cir.float>
+// CIR-AFTER-BASIC: %[[TMP_B:.*]] = cir.load{{.*}} %[[B_ADDR]] : !cir.ptr<!cir.complex<!cir.float>>, !cir.complex<!cir.float>
+// CIR-AFTER-BASIC: %[[A_REAL:.*]] = cir.complex.real %[[TMP_A]] : !cir.complex<!cir.float> -> !cir.float
+// CIR-AFTER-BASIC: %[[A_IMAG:.*]] = cir.complex.imag %[[TMP_A]] : !cir.complex<!cir.float> -> !cir.float
+// CIR-AFTER-BASIC: %[[B_REAL:.*]] = cir.complex.real %[[TMP_B]] : !cir.complex<!cir.float> -> !cir.float
+// CIR-AFTER-BASIC: %[[B_IMAG:.*]] = cir.complex.imag %[[TMP_B]] : !cir.complex<!cir.float> -> !cir.float
+// CIR-AFTER-BASIC: %[[MUL_AR_BR:.*]] = cir.binop(mul, %[[A_REAL]], %[[B_REAL]]) : !cir.float
+// CIR-AFTER-BASIC: %[[MUL_AI_BI:.*]] = cir.binop(mul, %[[A_IMAG]], %[[B_IMAG]]) : !cir.float
+// CIR-AFTER-BASIC: %[[MUL_AR_BI:.*]] = cir.binop(mul, %[[A_REAL]], %[[B_IMAG]]) : !cir.float
+// CIR-AFTER-BASIC: %[[MUL_AI_BR:.*]] = cir.binop(mul, %[[A_IMAG]], %[[B_REAL]]) : !cir.float
+// CIR-AFTER-BASIC: %[[C_REAL:.*]] = cir.binop(sub, %[[MUL_AR_BR]], %[[MUL_AI_BI]]) : !cir.float
+// CIR-AFTER-BASIC: %[[C_IMAG:.*]] = cir.binop(add, %[[MUL_AR_BI]], %[[MUL_AI_BR]]) : !cir.float
+// CIR-AFTER-BASIC: %[[RESULT:.*]] = cir.complex.create %[[C_REAL]], %[[C_IMAG]] : !cir.float -> !cir.complex<!cir.float>
+// CIR-AFTER-BASIC: cir.store{{.*}} %[[RESULT]], %[[C_ADDR]] : !cir.complex<!cir.float>, !cir.ptr<!cir.complex<!cir.float>>
+
+// LLVM-BASIC: %[[A_ADDR:.*]] = alloca { float, float }, i64 1, align 4
+// LLVM-BASIC: %[[B_ADDR:.*]] = alloca { float, float }, i64 1, align 4
+// LLVM-BASIC: %[[C_ADDR:.*]] = alloca { float, float }, i64 1, align 4
+// LLVM-BASIC: %[[TMP_A:.*]] = load { float, float }, ptr %[[A_ADDR]], align 4
+// LLVM-BASIC: %[[TMP_B:.*]] = load { float, float }, ptr %[[B_ADDR]], align 4
+// LLVM-BASIC: %[[A_REAL:.*]] = extractvalue { float, float } %[[TMP_A]], 0
+// LLVM-BASIC: %[[A_IMAG:.*]] = extractvalue { float, float } %[[TMP_A]], 1
+// LLVM-BASIC: %[[B_REAL:.*]] = extractvalue { float, float } %[[TMP_B]], 0
+// LLVM-BASIC: %[[B_IMAG:.*]] = extractvalue { float, float } %[[TMP_B]], 1
+// LLVM-BASIC: %[[MUL_AR_BR:.*]] = fmul float %[[A_REAL]], %[[B_REAL]]
+// LLVM-BASIC: %[[MUL_AI_BI:.*]] = fmul float %[[A_IMAG]], %[[B_IMAG]]
+// LLVM-BASIC: %[[MUL_AR_BI:.*]] = fmul float %[[A_REAL]], %[[B_IMAG]]
+// LLVM-BASIC: %[[MUL_AI_BR:.*]] = fmul float %[[A_IMAG]], %[[B_REAL]]
+// LLVM-BASIC: %[[C_REAL:.*]] = fsub float %[[MUL_AR_BR]], %[[MUL_AI_BI]]
+// LLVM-BASIC: %[[C_IMAG:.*]] = fadd float %[[MUL_AR_BI]], %[[MUL_AI_BR]]
+// LLVM-BASIC: %[[MUL_A_B:.*]] = insertvalue { float, float } {{.*}}, float %[[C_REAL]], 0
+// LLVM-BASIC: %[[RESULT:.*]] = insertvalue { float, float } %[[MUL_A_B]], float %[[C_IMAG]], 1
+// LLVM-BASIC: store { float, float } %[[RESULT]], ptr %[[C_ADDR]], align 4
+
+// OGCG-BASIC: %[[A_ADDR:.*]] = alloca { float, float }, align 4
+// OGCG-BASIC: %[[B_ADDR:.*]] = alloca { float, float }, align 4
+// OGCG-BASIC: %[[C_ADDR:.*]] = alloca { float, float }, align 4
+// OGCG-BASIC: %[[A_REAL_PTR:.*]] = getelementptr inbounds nuw { float, float }, ptr %[[A_ADDR]], i32 0, i32 0
+// OGCG-BASIC: %[[A_REAL:.*]] = load float, ptr %[[A_REAL_PTR]], align 4
+// OGCG-BASIC: %[[A_IMAG_PTR:.*]] = getelementptr inbounds nuw { floa...
[truncated]
|
|
@llvm/pr-subscribers-clangir Author: Amr Hesham (AmrDeveloper) ChangesThis change adds support for mul op for ComplexType Patch is 39.94 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/150834.diff 5 Files Affected:
diff --git a/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h b/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h
index 5c04d59475b6a..aea8315c274a5 100644
--- a/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h
+++ b/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h
@@ -441,6 +441,10 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
return create<cir::CmpOp>(loc, getBoolTy(), kind, lhs, rhs);
}
+ mlir::Value createIsNaN(mlir::Location loc, mlir::Value operand) {
+ return createCompare(loc, cir::CmpOpKind::ne, operand, operand);
+ }
+
mlir::Value createShift(mlir::Location loc, mlir::Value lhs, mlir::Value rhs,
bool isShiftLeft) {
return create<cir::ShiftOp>(loc, lhs.getType(), lhs, rhs, isShiftLeft);
diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index e2ddbd12c77bd..eb724dd299d3b 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -2804,6 +2804,47 @@ def CIR_ComplexSubOp : CIR_Op<"complex.sub", [
}];
}
+//===----------------------------------------------------------------------===//
+// ComplexMulOp
+//===----------------------------------------------------------------------===//
+
+def CIR_ComplexRangeKind : CIR_I32EnumAttr<
+ "ComplexRangeKind", "complex multiplication and division implementation", [
+ I32EnumAttrCase<"Full", 0, "full">,
+ I32EnumAttrCase<"Improved", 1, "improved">,
+ I32EnumAttrCase<"Promoted", 2, "promoted">,
+ I32EnumAttrCase<"Basic", 3, "basic">,
+ I32EnumAttrCase<"None", 4, "none">
+]>;
+
+def CIR_ComplexMulOp : CIR_Op<"complex.mul", [
+ Pure, SameOperandsAndResultType
+]> {
+ let summary = "Complex multiplication";
+ let description = [{
+ The `cir.complex.mul` operation takes two complex numbers and returns
+ their product.
+
+ Example:
+
+ ```mlir
+ %2 = cir.complex.mul %0, %1 range(basic) : !cir.complex<!cir.float>
+ %2 = cir.complex.mul %0, %1 range(full) : !cir.complex<!cir.float>
+ ```
+ }];
+
+ let arguments = (ins
+ CIR_ComplexType:$lhs,
+ CIR_ComplexType:$rhs,
+ CIR_ComplexRangeKind:$range
+ );
+
+ let results = (outs CIR_ComplexType:$result);
+
+ let assemblyFormat = [{
+ $lhs `,` $rhs `range` `(` $range `)` `:` qualified(type($result)) attr-dict
+ }];
+}
//===----------------------------------------------------------------------===//
// Bit Manipulation Operations
diff --git a/clang/lib/CIR/CodeGen/CIRGenExprComplex.cpp b/clang/lib/CIR/CodeGen/CIRGenExprComplex.cpp
index 02685a3d64121..196df0e9e6405 100644
--- a/clang/lib/CIR/CodeGen/CIRGenExprComplex.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenExprComplex.cpp
@@ -110,6 +110,7 @@ class ComplexExprEmitter : public StmtVisitor<ComplexExprEmitter, mlir::Value> {
mlir::Value emitBinAdd(const BinOpInfo &op);
mlir::Value emitBinSub(const BinOpInfo &op);
+ mlir::Value emitBinMul(const BinOpInfo &op);
QualType getPromotionType(QualType ty, bool isDivOpCode = false) {
if (auto *complexTy = ty->getAs<ComplexType>()) {
@@ -142,6 +143,7 @@ class ComplexExprEmitter : public StmtVisitor<ComplexExprEmitter, mlir::Value> {
HANDLEBINOP(Add)
HANDLEBINOP(Sub)
+ HANDLEBINOP(Mul)
#undef HANDLEBINOP
};
} // namespace
@@ -531,6 +533,7 @@ mlir::Value ComplexExprEmitter::emitPromoted(const Expr *e,
return emitBin##OP(emitBinOps(bo, promotionTy));
HANDLE_BINOP(Add)
HANDLE_BINOP(Sub)
+ HANDLE_BINOP(Mul)
#undef HANDLE_BINOP
default:
break;
@@ -582,6 +585,30 @@ mlir::Value ComplexExprEmitter::emitBinSub(const BinOpInfo &op) {
return builder.create<cir::ComplexSubOp>(op.loc, op.lhs, op.rhs);
}
+static cir::ComplexRangeKind
+getComplexRangeAttr(LangOptions::ComplexRangeKind range) {
+ switch (range) {
+ case LangOptions::CX_Full:
+ return cir::ComplexRangeKind::Full;
+ case LangOptions::CX_Improved:
+ return cir::ComplexRangeKind::Improved;
+ case LangOptions::CX_Promoted:
+ return cir::ComplexRangeKind::Promoted;
+ case LangOptions::CX_Basic:
+ return cir::ComplexRangeKind::Basic;
+ case LangOptions::CX_None:
+ return cir::ComplexRangeKind::None;
+ }
+}
+
+mlir::Value ComplexExprEmitter::emitBinMul(const BinOpInfo &op) {
+ assert(!cir::MissingFeatures::fastMathFlags());
+ assert(!cir::MissingFeatures::cgFPOptionsRAII());
+ cir::ComplexRangeKind rangeKind =
+ getComplexRangeAttr(op.fpFeatures.getComplexRange());
+ return builder.create<cir::ComplexMulOp>(op.loc, op.lhs, op.rhs, rangeKind);
+}
+
LValue CIRGenFunction::emitComplexAssignmentLValue(const BinaryOperator *e) {
assert(e->getOpcode() == BO_Assign && "Expected assign op");
diff --git a/clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp b/clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp
index cef83eae2ef50..a2ce550819183 100644
--- a/clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp
+++ b/clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp
@@ -15,7 +15,6 @@
#include "clang/CIR/Dialect/Passes.h"
#include "clang/CIR/MissingFeatures.h"
-#include <iostream>
#include <memory>
using namespace mlir;
@@ -28,20 +27,46 @@ struct LoweringPreparePass : public LoweringPrepareBase<LoweringPreparePass> {
void runOnOp(mlir::Operation *op);
void lowerCastOp(cir::CastOp op);
+ void lowerComplexMulOp(cir::ComplexMulOp op);
void lowerUnaryOp(cir::UnaryOp op);
void lowerArrayCtor(ArrayCtor op);
+ cir::FuncOp buildRuntimeFunction(
+ mlir::OpBuilder &builder, llvm::StringRef name, mlir::Location loc,
+ cir::FuncType type,
+ cir::GlobalLinkageKind linkage = cir::GlobalLinkageKind::ExternalLinkage);
+
///
/// AST related
/// -----------
clang::ASTContext *astCtx;
+ /// Tracks current module.
+ mlir::ModuleOp theModule;
+
void setASTContext(clang::ASTContext *c) { astCtx = c; }
};
} // namespace
+cir::FuncOp LoweringPreparePass::buildRuntimeFunction(
+ mlir::OpBuilder &builder, llvm::StringRef name, mlir::Location loc,
+ cir::FuncType type, cir::GlobalLinkageKind linkage) {
+ cir::FuncOp f = dyn_cast_or_null<FuncOp>(SymbolTable::lookupNearestSymbolFrom(
+ theModule, StringAttr::get(theModule->getContext(), name)));
+ if (!f) {
+ f = builder.create<cir::FuncOp>(loc, name, type);
+ f.setLinkageAttr(
+ cir::GlobalLinkageKindAttr::get(builder.getContext(), linkage));
+ mlir::SymbolTable::setSymbolVisibility(
+ f, mlir::SymbolTable::Visibility::Private);
+
+ assert(!cir::MissingFeatures::opFuncExtraAttrs());
+ }
+ return f;
+}
+
static mlir::Value lowerScalarToComplexCast(mlir::MLIRContext &ctx,
cir::CastOp op) {
cir::CIRBaseBuilderTy builder(ctx);
@@ -127,6 +152,120 @@ void LoweringPreparePass::lowerCastOp(cir::CastOp op) {
}
}
+static mlir::Value buildComplexBinOpLibCall(
+ LoweringPreparePass &pass, CIRBaseBuilderTy &builder,
+ llvm::StringRef (*libFuncNameGetter)(llvm::APFloat::Semantics),
+ mlir::Location loc, cir::ComplexType ty, mlir::Value lhsReal,
+ mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag) {
+ cir::FPTypeInterface elementTy =
+ mlir::cast<cir::FPTypeInterface>(ty.getElementType());
+
+ llvm::StringRef libFuncName = libFuncNameGetter(
+ llvm::APFloat::SemanticsToEnum(elementTy.getFloatSemantics()));
+ llvm::SmallVector<mlir::Type, 4> libFuncInputTypes(4, elementTy);
+
+ cir::FuncType libFuncTy = cir::FuncType::get(libFuncInputTypes, ty);
+
+ cir::FuncOp libFunc;
+ {
+ mlir::OpBuilder::InsertionGuard ipGuard{builder};
+ builder.setInsertionPointToStart(pass.theModule.getBody());
+ libFunc = pass.buildRuntimeFunction(builder, libFuncName, loc, libFuncTy);
+ }
+
+ cir::CallOp call =
+ builder.createCallOp(loc, libFunc, {lhsReal, lhsImag, rhsReal, rhsImag});
+ return call.getResult();
+}
+
+static llvm::StringRef
+getComplexMulLibCallName(llvm::APFloat::Semantics semantics) {
+ switch (semantics) {
+ case llvm::APFloat::S_IEEEhalf:
+ return "__mulhc3";
+ case llvm::APFloat::S_IEEEsingle:
+ return "__mulsc3";
+ case llvm::APFloat::S_IEEEdouble:
+ return "__muldc3";
+ case llvm::APFloat::S_PPCDoubleDouble:
+ return "__multc3";
+ case llvm::APFloat::S_x87DoubleExtended:
+ return "__mulxc3";
+ case llvm::APFloat::S_IEEEquad:
+ return "__multc3";
+ default:
+ llvm_unreachable("unsupported floating point type");
+ }
+}
+
+static mlir::Value lowerComplexMul(LoweringPreparePass &pass,
+ CIRBaseBuilderTy &builder,
+ mlir::Location loc, cir::ComplexMulOp op,
+ mlir::Value lhsReal, mlir::Value lhsImag,
+ mlir::Value rhsReal, mlir::Value rhsImag) {
+ // (a+bi) * (c+di) = (ac-bd) + (ad+bc)i
+ mlir::Value resultRealLhs =
+ builder.createBinop(loc, lhsReal, cir::BinOpKind::Mul, rhsReal);
+ mlir::Value resultRealRhs =
+ builder.createBinop(loc, lhsImag, cir::BinOpKind::Mul, rhsImag);
+ mlir::Value resultImagLhs =
+ builder.createBinop(loc, lhsReal, cir::BinOpKind::Mul, rhsImag);
+ mlir::Value resultImagRhs =
+ builder.createBinop(loc, lhsImag, cir::BinOpKind::Mul, rhsReal);
+ mlir::Value resultReal = builder.createBinop(
+ loc, resultRealLhs, cir::BinOpKind::Sub, resultRealRhs);
+ mlir::Value resultImag = builder.createBinop(
+ loc, resultImagLhs, cir::BinOpKind::Add, resultImagRhs);
+ mlir::Value algebraicResult =
+ builder.createComplexCreate(loc, resultReal, resultImag);
+
+ cir::ComplexType complexTy = op.getType();
+ cir::ComplexRangeKind rangeKind = op.getRange();
+ if (mlir::isa<cir::IntType>(complexTy.getElementType()) ||
+ rangeKind == cir::ComplexRangeKind::Basic ||
+ rangeKind == cir::ComplexRangeKind::Improved ||
+ rangeKind == cir::ComplexRangeKind::Promoted)
+ return algebraicResult;
+
+ // Check whether the real part and the imaginary part of the result are both
+ // NaN. If so, emit a library call to compute the multiplication instead.
+ // We check a value against NaN by comparing the value against itself.
+ mlir::Value resultRealIsNaN = builder.createIsNaN(loc, resultReal);
+ mlir::Value resultImagIsNaN = builder.createIsNaN(loc, resultImag);
+ mlir::Value resultRealAndImagAreNaN =
+ builder.createLogicalAnd(loc, resultRealIsNaN, resultImagIsNaN);
+
+ return builder
+ .create<cir::TernaryOp>(
+ loc, resultRealAndImagAreNaN,
+ [&](mlir::OpBuilder &, mlir::Location) {
+ mlir::Value libCallResult = buildComplexBinOpLibCall(
+ pass, builder, &getComplexMulLibCallName, loc, complexTy,
+ lhsReal, lhsImag, rhsReal, rhsImag);
+ builder.createYield(loc, libCallResult);
+ },
+ [&](mlir::OpBuilder &, mlir::Location) {
+ builder.createYield(loc, algebraicResult);
+ })
+ .getResult();
+}
+
+void LoweringPreparePass::lowerComplexMulOp(cir::ComplexMulOp op) {
+ cir::CIRBaseBuilderTy builder(getContext());
+ builder.setInsertionPointAfter(op);
+ mlir::Location loc = op.getLoc();
+ mlir::TypedValue<cir::ComplexType> lhs = op.getLhs();
+ mlir::TypedValue<cir::ComplexType> rhs = op.getRhs();
+ mlir::Value lhsReal = builder.createComplexReal(loc, lhs);
+ mlir::Value lhsImag = builder.createComplexImag(loc, lhs);
+ mlir::Value rhsReal = builder.createComplexReal(loc, rhs);
+ mlir::Value rhsImag = builder.createComplexImag(loc, rhs);
+ mlir::Value loweredResult = lowerComplexMul(*this, builder, loc, op, lhsReal,
+ lhsImag, rhsReal, rhsImag);
+ op.replaceAllUsesWith(loweredResult);
+ op.erase();
+}
+
void LoweringPreparePass::lowerUnaryOp(cir::UnaryOp op) {
mlir::Type ty = op.getType();
if (!mlir::isa<cir::ComplexType>(ty))
@@ -247,17 +386,23 @@ void LoweringPreparePass::runOnOp(mlir::Operation *op) {
lowerArrayCtor(arrayCtor);
else if (auto cast = mlir::dyn_cast<cir::CastOp>(op))
lowerCastOp(cast);
+ else if (auto complexMul = mlir::dyn_cast<cir::ComplexMulOp>(op))
+ lowerComplexMulOp(complexMul);
else if (auto unary = mlir::dyn_cast<cir::UnaryOp>(op))
lowerUnaryOp(unary);
}
void LoweringPreparePass::runOnOperation() {
mlir::Operation *op = getOperation();
+ if (isa<::mlir::ModuleOp>(op)) {
+ theModule = cast<::mlir::ModuleOp>(op);
+ }
llvm::SmallVector<mlir::Operation *> opsToTransform;
op->walk([&](mlir::Operation *op) {
- if (mlir::isa<cir::ArrayCtor, cir::CastOp, cir::UnaryOp>(op))
+ if (mlir::isa<cir::ArrayCtor, cir::CastOp, cir::ComplexMulOp, cir::UnaryOp>(
+ op))
opsToTransform.push_back(op);
});
diff --git a/clang/test/CIR/CodeGen/complex-mul-div.cpp b/clang/test/CIR/CodeGen/complex-mul-div.cpp
new file mode 100644
index 0000000000000..5fe682ac28b53
--- /dev/null
+++ b/clang/test/CIR/CodeGen/complex-mul-div.cpp
@@ -0,0 +1,325 @@
+// complex-range basic
+// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -complex-range=basic -Wno-unused-value -fclangir -emit-cir -mmlir --mlir-print-ir-before=cir-canonicalize -o %t.cir %s 2>&1 | FileCheck --check-prefix=CIR-BEFORE-BASIC %s
+// RUN: %clang_cc1 -std=c++20 -triple x86_64-unknown-linux-gnu -complex-range=basic -Wno-unused-value -fclangir -emit-cir %s -o %t.cir
+// RUN: FileCheck --input-file=%t.cir %s -check-prefix=CIR-AFTER-BASIC
+// RUN: %clang_cc1 -std=c++20 -triple x86_64-unknown-linux-gnu -complex-range=basic -Wno-unused-value -fclangir -emit-llvm %s -o %t-cir.ll
+// RUN: FileCheck --input-file=%t-cir.ll %s -check-prefix=LLVM-BASIC
+// RUN: %clang_cc1 -std=c++20 -triple x86_64-unknown-linux-gnu -complex-range=basic -Wno-unused-value -emit-llvm %s -o %t.ll
+// RUN: FileCheck --input-file=%t.ll %s -check-prefix=OGCG-BASIC
+
+// complex-range improved
+// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -complex-range=improved -Wno-unused-value -fclangir -emit-cir -mmlir --mlir-print-ir-before=cir-canonicalize -o %t.cir %s 2>&1 | FileCheck --check-prefix=CIR-BEFORE-IMPROVED %s
+// RUN: %clang_cc1 -std=c++20 -triple x86_64-unknown-linux-gnu -complex-range=improved -Wno-unused-value -fclangir -emit-cir %s -o %t.cir
+// RUN: FileCheck --input-file=%t.cir %s -check-prefix=CIR-AFTER-IMPROVED
+// RUN: %clang_cc1 -std=c++20 -triple x86_64-unknown-linux-gnu -complex-range=improved -Wno-unused-value -fclangir -emit-llvm %s -o %t-cir.ll
+// RUN: FileCheck --input-file=%t-cir.ll %s -check-prefix=LLVM-IMPROVED
+// RUN: %clang_cc1 -std=c++20 -triple x86_64-unknown-linux-gnu -complex-range=improved -Wno-unused-value -emit-llvm %s -o %t.ll
+// RUN: FileCheck --input-file=%t.ll %s -check-prefix=OGCG-IMPROVED
+
+// complex-range promoted
+// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -complex-range=promoted -Wno-unused-value -fclangir -emit-cir -mmlir --mlir-print-ir-before=cir-canonicalize -o %t.cir %s 2>&1 | FileCheck --check-prefix=CIR-BEFORE-PROMOTED %s
+// RUN: %clang_cc1 -std=c++20 -triple x86_64-unknown-linux-gnu -complex-range=promoted -Wno-unused-value -fclangir -emit-cir %s -o %t.cir
+// RUN: FileCheck --input-file=%t.cir %s -check-prefix=CIR-AFTER-PROMOTED
+// RUN: %clang_cc1 -std=c++20 -triple x86_64-unknown-linux-gnu -complex-range=promoted -Wno-unused-value -fclangir -emit-llvm %s -o %t-cir.ll
+// RUN: FileCheck --input-file=%t-cir.ll %s -check-prefix=LLVM-PROMOTED
+// RUN: %clang_cc1 -std=c++20 -triple x86_64-unknown-linux-gnu -complex-range=promoted -Wno-unused-value -emit-llvm %s -o %t.ll
+// RUN: FileCheck --input-file=%t.ll %s -check-prefix=OGCG-PROMOTED
+
+// complex-range full
+// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -complex-range=full -Wno-unused-value -fclangir -emit-cir -mmlir --mlir-print-ir-before=cir-canonicalize -o %t.cir %s 2>&1 | FileCheck --check-prefix=CIR-BEFORE-FULL %s
+// RUN: %clang_cc1 -std=c++20 -triple x86_64-unknown-linux-gnu -complex-range=full -Wno-unused-value -fclangir -emit-cir %s -o %t.cir
+// RUN: FileCheck --input-file=%t.cir %s -check-prefix=CIR-AFTER-FULL
+// RUN: %clang_cc1 -std=c++20 -triple x86_64-unknown-linux-gnu -complex-range=full -Wno-unused-value -fclangir -emit-llvm %s -o %t-cir.ll
+// RUN: FileCheck --input-file=%t-cir.ll %s -check-prefix=LLVM-FULL
+// RUN: %clang_cc1 -std=c++20 -triple x86_64-unknown-linux-gnu -complex-range=full -Wno-unused-value -emit-llvm %s -o %t.ll
+// RUN: FileCheck --input-file=%t.ll %s -check-prefix=OGCG-FULL
+
+void foo() {
+ float _Complex a;
+ float _Complex b;
+ float _Complex c = a * b;
+}
+
+// CIR-BEFORE-BASIC: %{{.*}} = cir.complex.mul {{.*}}, {{.*}} range(basic) : !cir.complex<!cir.float>
+
+// CIR-AFTER-BASIC: %[[A_ADDR:.*]] = cir.alloca !cir.complex<!cir.float>, !cir.ptr<!cir.complex<!cir.float>>, ["a"]
+// CIR-AFTER-BASIC: %[[B_ADDR:.*]] = cir.alloca !cir.complex<!cir.float>, !cir.ptr<!cir.complex<!cir.float>>, ["b"]
+// CIR-AFTER-BASIC: %[[C_ADDR:.*]] = cir.alloca !cir.complex<!cir.float>, !cir.ptr<!cir.complex<!cir.float>>, ["c", init]
+// CIR-AFTER-BASIC: %[[TMP_A:.*]] = cir.load{{.*}} %[[A_ADDR]] : !cir.ptr<!cir.complex<!cir.float>>, !cir.complex<!cir.float>
+// CIR-AFTER-BASIC: %[[TMP_B:.*]] = cir.load{{.*}} %[[B_ADDR]] : !cir.ptr<!cir.complex<!cir.float>>, !cir.complex<!cir.float>
+// CIR-AFTER-BASIC: %[[A_REAL:.*]] = cir.complex.real %[[TMP_A]] : !cir.complex<!cir.float> -> !cir.float
+// CIR-AFTER-BASIC: %[[A_IMAG:.*]] = cir.complex.imag %[[TMP_A]] : !cir.complex<!cir.float> -> !cir.float
+// CIR-AFTER-BASIC: %[[B_REAL:.*]] = cir.complex.real %[[TMP_B]] : !cir.complex<!cir.float> -> !cir.float
+// CIR-AFTER-BASIC: %[[B_IMAG:.*]] = cir.complex.imag %[[TMP_B]] : !cir.complex<!cir.float> -> !cir.float
+// CIR-AFTER-BASIC: %[[MUL_AR_BR:.*]] = cir.binop(mul, %[[A_REAL]], %[[B_REAL]]) : !cir.float
+// CIR-AFTER-BASIC: %[[MUL_AI_BI:.*]] = cir.binop(mul, %[[A_IMAG]], %[[B_IMAG]]) : !cir.float
+// CIR-AFTER-BASIC: %[[MUL_AR_BI:.*]] = cir.binop(mul, %[[A_REAL]], %[[B_IMAG]]) : !cir.float
+// CIR-AFTER-BASIC: %[[MUL_AI_BR:.*]] = cir.binop(mul, %[[A_IMAG]], %[[B_REAL]]) : !cir.float
+// CIR-AFTER-BASIC: %[[C_REAL:.*]] = cir.binop(sub, %[[MUL_AR_BR]], %[[MUL_AI_BI]]) : !cir.float
+// CIR-AFTER-BASIC: %[[C_IMAG:.*]] = cir.binop(add, %[[MUL_AR_BI]], %[[MUL_AI_BR]]) : !cir.float
+// CIR-AFTER-BASIC: %[[RESULT:.*]] = cir.complex.create %[[C_REAL]], %[[C_IMAG]] : !cir.float -> !cir.complex<!cir.float>
+// CIR-AFTER-BASIC: cir.store{{.*}} %[[RESULT]], %[[C_ADDR]] : !cir.complex<!cir.float>, !cir.ptr<!cir.complex<!cir.float>>
+
+// LLVM-BASIC: %[[A_ADDR:.*]] = alloca { float, float }, i64 1, align 4
+// LLVM-BASIC: %[[B_ADDR:.*]] = alloca { float, float }, i64 1, align 4
+// LLVM-BASIC: %[[C_ADDR:.*]] = alloca { float, float }, i64 1, align 4
+// LLVM-BASIC: %[[TMP_A:.*]] = load { float, float }, ptr %[[A_ADDR]], align 4
+// LLVM-BASIC: %[[TMP_B:.*]] = load { float, float }, ptr %[[B_ADDR]], align 4
+// LLVM-BASIC: %[[A_REAL:.*]] = extractvalue { float, float } %[[TMP_A]], 0
+// LLVM-BASIC: %[[A_IMAG:.*]] = extractvalue { float, float } %[[TMP_A]], 1
+// LLVM-BASIC: %[[B_REAL:.*]] = extractvalue { float, float } %[[TMP_B]], 0
+// LLVM-BASIC: %[[B_IMAG:.*]] = extractvalue { float, float } %[[TMP_B]], 1
+// LLVM-BASIC: %[[MUL_AR_BR:.*]] = fmul float %[[A_REAL]], %[[B_REAL]]
+// LLVM-BASIC: %[[MUL_AI_BI:.*]] = fmul float %[[A_IMAG]], %[[B_IMAG]]
+// LLVM-BASIC: %[[MUL_AR_BI:.*]] = fmul float %[[A_REAL]], %[[B_IMAG]]
+// LLVM-BASIC: %[[MUL_AI_BR:.*]] = fmul float %[[A_IMAG]], %[[B_REAL]]
+// LLVM-BASIC: %[[C_REAL:.*]] = fsub float %[[MUL_AR_BR]], %[[MUL_AI_BI]]
+// LLVM-BASIC: %[[C_IMAG:.*]] = fadd float %[[MUL_AR_BI]], %[[MUL_AI_BR]]
+// LLVM-BASIC: %[[MUL_A_B:.*]] = insertvalue { float, float } {{.*}}, float %[[C_REAL]], 0
+// LLVM-BASIC: %[[RESULT:.*]] = insertvalue { float, float } %[[MUL_A_B]], float %[[C_IMAG]], 1
+// LLVM-BASIC: store { float, float } %[[RESULT]], ptr %[[C_ADDR]], align 4
+
+// OGCG-BASIC: %[[A_ADDR:.*]] = alloca { float, float }, align 4
+// OGCG-BASIC: %[[B_ADDR:.*]] = alloca { float, float }, align 4
+// OGCG-BASIC: %[[C_ADDR:.*]] = alloca { float, float }, align 4
+// OGCG-BASIC: %[[A_REAL_PTR:.*]] = getelementptr inbounds nuw { float, float }, ptr %[[A_ADDR]], i32 0, i32 0
+// OGCG-BASIC: %[[A_REAL:.*]] = load float, ptr %[[A_REAL_PTR]], align 4
+// OGCG-BASIC: %[[A_IMAG_PTR:.*]] = getelementptr inbounds nuw { floa...
[truncated]
|
|
Regarding the option for running the LoweringPrepare pass only if the target file is not CIR, I will need modifications on Complex unary, cast and ArrayCtor. If it's okay, I can update all of them in a follow-up PR |
Seems like a natural incremental change to me. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This doesn't need braces.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| mlir::ModuleOp theModule; | |
| mlir::ModuleOp mlirModule; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can use the same set of 'after' and LLVM/OGCG checks for BASIC, IMPROVED, and PROMOTED.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it's for improvement and promotes its possible, but with basic we can't use the same set because in complex div basic will be different than the other
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Edit: I think we can't use the same prefix for them together because once I add ComplexDiv, every RangeKind can produce a different result
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can specify multiple check prefixes on the same run line using --check-prefixes so you could have --check-prefixes=BASIC,MUL_COMBINED for the basic run line and --check-prefixes=IMPROVED,MUL_COMBINED for the improved line, etc., when you need it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I updated them and created a prefix for int, too.
Thanks
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add comment here why full and not none?
b8bb50b to
350726f
Compare
andykaylor
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks good except for the 'range' description. I've suggested something more verbose, which is what I intended in my original request for that comment.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| Range is used to controls the various implementations for complex | |
| Range is used to select the implementation used when the operation | |
| is lowered to the LLVM dialect. For multiplication, 'improved', | |
| 'promoted', and 'basic' are all handled equivalently, producing the | |
| algebraic formula with no special handling for NaN value. If 'full' is | |
| used, a runtime-library function is called if one of the intermediate | |
| calculations produced a NaN value. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for writing this :D
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| // Inserting a declaration for the runtime function to be used in Complex | |
| // Insert a declaration for the runtime function to be used in Complex |
6eede5f to
0474585
Compare
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/134/builds/23615 Here is the relevant piece of the build log for the reference |
This change adds support for mul op for ComplexType
#141365