Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,10 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
return create<cir::CmpOp>(loc, getBoolTy(), kind, lhs, rhs);
}

mlir::Value createIsNaN(mlir::Location loc, mlir::Value operand) {
return createCompare(loc, cir::CmpOpKind::ne, operand, operand);
}

mlir::Value createShift(mlir::Location loc, mlir::Value lhs, mlir::Value rhs,
bool isShiftLeft) {
return create<cir::ShiftOp>(loc, lhs.getType(), lhs, rhs, isShiftLeft);
Expand Down
47 changes: 47 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2823,6 +2823,53 @@ def CIR_ComplexSubOp : CIR_Op<"complex.sub", [
}];
}

//===----------------------------------------------------------------------===//
// ComplexMulOp
//===----------------------------------------------------------------------===//

def CIR_ComplexRangeKind : CIR_I32EnumAttr<
"ComplexRangeKind", "complex multiplication and division implementation", [
I32EnumAttrCase<"Full", 0, "full">,
I32EnumAttrCase<"Improved", 1, "improved">,
I32EnumAttrCase<"Promoted", 2, "promoted">,
I32EnumAttrCase<"Basic", 3, "basic">,
]>;

def CIR_ComplexMulOp : CIR_Op<"complex.mul", [
Pure, SameOperandsAndResultType
]> {
let summary = "Complex multiplication";
let description = [{
The `cir.complex.mul` operation takes two complex numbers and returns
their product.

Range is used to select the implementation used when the operation
is lowered to the LLVM dialect. For multiplication, 'improved',
'promoted', and 'basic' are all handled equivalently, producing the
algebraic formula with no special handling for NaN value. If 'full' is
used, a runtime-library function is called if one of the intermediate
calculations produced a NaN value.

Example:

```mlir
%2 = cir.complex.mul %0, %1 range(basic) : !cir.complex<!cir.float>
%2 = cir.complex.mul %0, %1 range(full) : !cir.complex<!cir.float>
```
}];

let arguments = (ins
CIR_ComplexType:$lhs,
CIR_ComplexType:$rhs,
CIR_ComplexRangeKind:$range
);

let results = (outs CIR_ComplexType:$result);

let assemblyFormat = [{
$lhs `,` $rhs `range` `(` $range `)` `:` qualified(type($result)) attr-dict
}];
}

//===----------------------------------------------------------------------===//
// Bit Manipulation Operations
Expand Down
28 changes: 28 additions & 0 deletions clang/lib/CIR/CodeGen/CIRGenExprComplex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ class ComplexExprEmitter : public StmtVisitor<ComplexExprEmitter, mlir::Value> {

mlir::Value emitBinAdd(const BinOpInfo &op);
mlir::Value emitBinSub(const BinOpInfo &op);
mlir::Value emitBinMul(const BinOpInfo &op);

QualType getPromotionType(QualType ty, bool isDivOpCode = false) {
if (auto *complexTy = ty->getAs<ComplexType>()) {
Expand Down Expand Up @@ -150,6 +151,7 @@ class ComplexExprEmitter : public StmtVisitor<ComplexExprEmitter, mlir::Value> {

HANDLEBINOP(Add)
HANDLEBINOP(Sub)
HANDLEBINOP(Mul)
#undef HANDLEBINOP
};
} // namespace
Expand Down Expand Up @@ -577,6 +579,7 @@ mlir::Value ComplexExprEmitter::emitPromoted(const Expr *e,
return emitBin##OP(emitBinOps(bo, promotionTy));
HANDLE_BINOP(Add)
HANDLE_BINOP(Sub)
HANDLE_BINOP(Mul)
#undef HANDLE_BINOP
default:
break;
Expand Down Expand Up @@ -636,6 +639,31 @@ mlir::Value ComplexExprEmitter::emitBinSub(const BinOpInfo &op) {
return builder.create<cir::ComplexSubOp>(op.loc, op.lhs, op.rhs);
}

static cir::ComplexRangeKind
getComplexRangeAttr(LangOptions::ComplexRangeKind range) {
switch (range) {
case LangOptions::CX_Full:
return cir::ComplexRangeKind::Full;
case LangOptions::CX_Improved:
return cir::ComplexRangeKind::Improved;
case LangOptions::CX_Promoted:
return cir::ComplexRangeKind::Promoted;
case LangOptions::CX_Basic:
return cir::ComplexRangeKind::Basic;
case LangOptions::CX_None:
// The default value for ComplexRangeKind is Full is no option is selected
return cir::ComplexRangeKind::Full;
}
}

mlir::Value ComplexExprEmitter::emitBinMul(const BinOpInfo &op) {
assert(!cir::MissingFeatures::fastMathFlags());
assert(!cir::MissingFeatures::cgFPOptionsRAII());
cir::ComplexRangeKind rangeKind =
getComplexRangeAttr(op.fpFeatures.getComplexRange());
return builder.create<cir::ComplexMulOp>(op.loc, op.lhs, op.rhs, rangeKind);
}

LValue CIRGenFunction::emitComplexAssignmentLValue(const BinaryOperator *e) {
assert(e->getOpcode() == BO_Assign && "Expected assign op");

Expand Down
153 changes: 150 additions & 3 deletions clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
#include "clang/CIR/Dialect/Passes.h"
#include "clang/CIR/MissingFeatures.h"

#include <iostream>
#include <memory>

using namespace mlir;
Expand All @@ -28,21 +27,47 @@ struct LoweringPreparePass : public LoweringPrepareBase<LoweringPreparePass> {

void runOnOp(mlir::Operation *op);
void lowerCastOp(cir::CastOp op);
void lowerComplexMulOp(cir::ComplexMulOp op);
void lowerUnaryOp(cir::UnaryOp op);
void lowerArrayDtor(cir::ArrayDtor op);
void lowerArrayCtor(cir::ArrayCtor op);

cir::FuncOp buildRuntimeFunction(
mlir::OpBuilder &builder, llvm::StringRef name, mlir::Location loc,
cir::FuncType type,
cir::GlobalLinkageKind linkage = cir::GlobalLinkageKind::ExternalLinkage);

///
/// AST related
/// -----------

clang::ASTContext *astCtx;

/// Tracks current module.
mlir::ModuleOp mlirModule;

void setASTContext(clang::ASTContext *c) { astCtx = c; }
};

} // namespace

cir::FuncOp LoweringPreparePass::buildRuntimeFunction(
mlir::OpBuilder &builder, llvm::StringRef name, mlir::Location loc,
cir::FuncType type, cir::GlobalLinkageKind linkage) {
cir::FuncOp f = dyn_cast_or_null<FuncOp>(SymbolTable::lookupNearestSymbolFrom(
mlirModule, StringAttr::get(mlirModule->getContext(), name)));
if (!f) {
f = builder.create<cir::FuncOp>(loc, name, type);
f.setLinkageAttr(
cir::GlobalLinkageKindAttr::get(builder.getContext(), linkage));
mlir::SymbolTable::setSymbolVisibility(
f, mlir::SymbolTable::Visibility::Private);

assert(!cir::MissingFeatures::opFuncExtraAttrs());
}
return f;
}

static mlir::Value lowerScalarToComplexCast(mlir::MLIRContext &ctx,
cir::CastOp op) {
cir::CIRBaseBuilderTy builder(ctx);
Expand Down Expand Up @@ -128,6 +153,124 @@ void LoweringPreparePass::lowerCastOp(cir::CastOp op) {
}
}

static mlir::Value buildComplexBinOpLibCall(
LoweringPreparePass &pass, CIRBaseBuilderTy &builder,
llvm::StringRef (*libFuncNameGetter)(llvm::APFloat::Semantics),
mlir::Location loc, cir::ComplexType ty, mlir::Value lhsReal,
mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag) {
cir::FPTypeInterface elementTy =
mlir::cast<cir::FPTypeInterface>(ty.getElementType());

llvm::StringRef libFuncName = libFuncNameGetter(
llvm::APFloat::SemanticsToEnum(elementTy.getFloatSemantics()));
llvm::SmallVector<mlir::Type, 4> libFuncInputTypes(4, elementTy);

cir::FuncType libFuncTy = cir::FuncType::get(libFuncInputTypes, ty);

// Insert a declaration for the runtime function to be used in Complex
// multiplication and division when needed
cir::FuncOp libFunc;
{
mlir::OpBuilder::InsertionGuard ipGuard{builder};
builder.setInsertionPointToStart(pass.mlirModule.getBody());
libFunc = pass.buildRuntimeFunction(builder, libFuncName, loc, libFuncTy);
}

cir::CallOp call =
builder.createCallOp(loc, libFunc, {lhsReal, lhsImag, rhsReal, rhsImag});
return call.getResult();
}

static llvm::StringRef
getComplexMulLibCallName(llvm::APFloat::Semantics semantics) {
switch (semantics) {
case llvm::APFloat::S_IEEEhalf:
return "__mulhc3";
case llvm::APFloat::S_IEEEsingle:
return "__mulsc3";
case llvm::APFloat::S_IEEEdouble:
return "__muldc3";
case llvm::APFloat::S_PPCDoubleDouble:
return "__multc3";
case llvm::APFloat::S_x87DoubleExtended:
return "__mulxc3";
case llvm::APFloat::S_IEEEquad:
return "__multc3";
default:
llvm_unreachable("unsupported floating point type");
}
}

static mlir::Value lowerComplexMul(LoweringPreparePass &pass,
CIRBaseBuilderTy &builder,
mlir::Location loc, cir::ComplexMulOp op,
mlir::Value lhsReal, mlir::Value lhsImag,
mlir::Value rhsReal, mlir::Value rhsImag) {
// (a+bi) * (c+di) = (ac-bd) + (ad+bc)i
mlir::Value resultRealLhs =
builder.createBinop(loc, lhsReal, cir::BinOpKind::Mul, rhsReal);
mlir::Value resultRealRhs =
builder.createBinop(loc, lhsImag, cir::BinOpKind::Mul, rhsImag);
mlir::Value resultImagLhs =
builder.createBinop(loc, lhsReal, cir::BinOpKind::Mul, rhsImag);
mlir::Value resultImagRhs =
builder.createBinop(loc, lhsImag, cir::BinOpKind::Mul, rhsReal);
mlir::Value resultReal = builder.createBinop(
loc, resultRealLhs, cir::BinOpKind::Sub, resultRealRhs);
mlir::Value resultImag = builder.createBinop(
loc, resultImagLhs, cir::BinOpKind::Add, resultImagRhs);
mlir::Value algebraicResult =
builder.createComplexCreate(loc, resultReal, resultImag);

cir::ComplexType complexTy = op.getType();
cir::ComplexRangeKind rangeKind = op.getRange();
if (mlir::isa<cir::IntType>(complexTy.getElementType()) ||
rangeKind == cir::ComplexRangeKind::Basic ||
rangeKind == cir::ComplexRangeKind::Improved ||
rangeKind == cir::ComplexRangeKind::Promoted)
return algebraicResult;

assert(!cir::MissingFeatures::fastMathFlags());

// Check whether the real part and the imaginary part of the result are both
// NaN. If so, emit a library call to compute the multiplication instead.
// We check a value against NaN by comparing the value against itself.
mlir::Value resultRealIsNaN = builder.createIsNaN(loc, resultReal);
mlir::Value resultImagIsNaN = builder.createIsNaN(loc, resultImag);
mlir::Value resultRealAndImagAreNaN =
builder.createLogicalAnd(loc, resultRealIsNaN, resultImagIsNaN);

return builder
.create<cir::TernaryOp>(
loc, resultRealAndImagAreNaN,
[&](mlir::OpBuilder &, mlir::Location) {
mlir::Value libCallResult = buildComplexBinOpLibCall(
pass, builder, &getComplexMulLibCallName, loc, complexTy,
lhsReal, lhsImag, rhsReal, rhsImag);
builder.createYield(loc, libCallResult);
},
[&](mlir::OpBuilder &, mlir::Location) {
builder.createYield(loc, algebraicResult);
})
.getResult();
}

void LoweringPreparePass::lowerComplexMulOp(cir::ComplexMulOp op) {
cir::CIRBaseBuilderTy builder(getContext());
builder.setInsertionPointAfter(op);
mlir::Location loc = op.getLoc();
mlir::TypedValue<cir::ComplexType> lhs = op.getLhs();
mlir::TypedValue<cir::ComplexType> rhs = op.getRhs();
mlir::Value lhsReal = builder.createComplexReal(loc, lhs);
mlir::Value lhsImag = builder.createComplexImag(loc, lhs);
mlir::Value rhsReal = builder.createComplexReal(loc, rhs);
mlir::Value rhsImag = builder.createComplexImag(loc, rhs);
mlir::Value loweredResult = lowerComplexMul(*this, builder, loc, op, lhsReal,
lhsImag, rhsReal, rhsImag);
op.replaceAllUsesWith(loweredResult);
op.erase();
}

void LoweringPreparePass::lowerUnaryOp(cir::UnaryOp op) {
mlir::Type ty = op.getType();
if (!mlir::isa<cir::ComplexType>(ty))
Expand Down Expand Up @@ -269,18 +412,22 @@ void LoweringPreparePass::runOnOp(mlir::Operation *op) {
lowerArrayDtor(arrayDtor);
else if (auto cast = mlir::dyn_cast<cir::CastOp>(op))
lowerCastOp(cast);
else if (auto complexMul = mlir::dyn_cast<cir::ComplexMulOp>(op))
lowerComplexMulOp(complexMul);
else if (auto unary = mlir::dyn_cast<cir::UnaryOp>(op))
lowerUnaryOp(unary);
}

void LoweringPreparePass::runOnOperation() {
mlir::Operation *op = getOperation();
if (isa<::mlir::ModuleOp>(op))
mlirModule = cast<::mlir::ModuleOp>(op);

llvm::SmallVector<mlir::Operation *> opsToTransform;

op->walk([&](mlir::Operation *op) {
if (mlir::isa<cir::ArrayCtor, cir::ArrayDtor, cir::CastOp, cir::UnaryOp>(
op))
if (mlir::isa<cir::ArrayCtor, cir::ArrayDtor, cir::CastOp,
cir::ComplexMulOp, cir::UnaryOp>(op))
opsToTransform.push_back(op);
});

Expand Down
Loading