Skip to content

Commit 96d4a91

Browse files
AmrDeveloperkrishna2803
authored andcommitted
[CIR] Upstream MulOp for ComplexType (llvm#150834)
This change adds support for mul op for ComplexType llvm#141365
1 parent deeecff commit 96d4a91

File tree

5 files changed

+509
-3
lines changed

5 files changed

+509
-3
lines changed

clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -447,6 +447,10 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
447447
return create<cir::CmpOp>(loc, getBoolTy(), kind, lhs, rhs);
448448
}
449449

450+
mlir::Value createIsNaN(mlir::Location loc, mlir::Value operand) {
451+
return createCompare(loc, cir::CmpOpKind::ne, operand, operand);
452+
}
453+
450454
mlir::Value createShift(mlir::Location loc, mlir::Value lhs, mlir::Value rhs,
451455
bool isShiftLeft) {
452456
return create<cir::ShiftOp>(loc, lhs.getType(), lhs, rhs, isShiftLeft);

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2823,6 +2823,53 @@ def CIR_ComplexSubOp : CIR_Op<"complex.sub", [
28232823
}];
28242824
}
28252825

2826+
//===----------------------------------------------------------------------===//
2827+
// ComplexMulOp
2828+
//===----------------------------------------------------------------------===//
2829+
2830+
def CIR_ComplexRangeKind : CIR_I32EnumAttr<
2831+
"ComplexRangeKind", "complex multiplication and division implementation", [
2832+
I32EnumAttrCase<"Full", 0, "full">,
2833+
I32EnumAttrCase<"Improved", 1, "improved">,
2834+
I32EnumAttrCase<"Promoted", 2, "promoted">,
2835+
I32EnumAttrCase<"Basic", 3, "basic">,
2836+
]>;
2837+
2838+
def CIR_ComplexMulOp : CIR_Op<"complex.mul", [
2839+
Pure, SameOperandsAndResultType
2840+
]> {
2841+
let summary = "Complex multiplication";
2842+
let description = [{
2843+
The `cir.complex.mul` operation takes two complex numbers and returns
2844+
their product.
2845+
2846+
Range is used to select the implementation used when the operation
2847+
is lowered to the LLVM dialect. For multiplication, 'improved',
2848+
'promoted', and 'basic' are all handled equivalently, producing the
2849+
algebraic formula with no special handling for NaN value. If 'full' is
2850+
used, a runtime-library function is called if one of the intermediate
2851+
calculations produced a NaN value.
2852+
2853+
Example:
2854+
2855+
```mlir
2856+
%2 = cir.complex.mul %0, %1 range(basic) : !cir.complex<!cir.float>
2857+
%2 = cir.complex.mul %0, %1 range(full) : !cir.complex<!cir.float>
2858+
```
2859+
}];
2860+
2861+
let arguments = (ins
2862+
CIR_ComplexType:$lhs,
2863+
CIR_ComplexType:$rhs,
2864+
CIR_ComplexRangeKind:$range
2865+
);
2866+
2867+
let results = (outs CIR_ComplexType:$result);
2868+
2869+
let assemblyFormat = [{
2870+
$lhs `,` $rhs `range` `(` $range `)` `:` qualified(type($result)) attr-dict
2871+
}];
2872+
}
28262873

28272874
//===----------------------------------------------------------------------===//
28282875
// Bit Manipulation Operations

clang/lib/CIR/CodeGen/CIRGenExprComplex.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ class ComplexExprEmitter : public StmtVisitor<ComplexExprEmitter, mlir::Value> {
118118

119119
mlir::Value emitBinAdd(const BinOpInfo &op);
120120
mlir::Value emitBinSub(const BinOpInfo &op);
121+
mlir::Value emitBinMul(const BinOpInfo &op);
121122

122123
QualType getPromotionType(QualType ty, bool isDivOpCode = false) {
123124
if (auto *complexTy = ty->getAs<ComplexType>()) {
@@ -150,6 +151,7 @@ class ComplexExprEmitter : public StmtVisitor<ComplexExprEmitter, mlir::Value> {
150151

151152
HANDLEBINOP(Add)
152153
HANDLEBINOP(Sub)
154+
HANDLEBINOP(Mul)
153155
#undef HANDLEBINOP
154156
};
155157
} // namespace
@@ -577,6 +579,7 @@ mlir::Value ComplexExprEmitter::emitPromoted(const Expr *e,
577579
return emitBin##OP(emitBinOps(bo, promotionTy));
578580
HANDLE_BINOP(Add)
579581
HANDLE_BINOP(Sub)
582+
HANDLE_BINOP(Mul)
580583
#undef HANDLE_BINOP
581584
default:
582585
break;
@@ -636,6 +639,31 @@ mlir::Value ComplexExprEmitter::emitBinSub(const BinOpInfo &op) {
636639
return builder.create<cir::ComplexSubOp>(op.loc, op.lhs, op.rhs);
637640
}
638641

642+
static cir::ComplexRangeKind
643+
getComplexRangeAttr(LangOptions::ComplexRangeKind range) {
644+
switch (range) {
645+
case LangOptions::CX_Full:
646+
return cir::ComplexRangeKind::Full;
647+
case LangOptions::CX_Improved:
648+
return cir::ComplexRangeKind::Improved;
649+
case LangOptions::CX_Promoted:
650+
return cir::ComplexRangeKind::Promoted;
651+
case LangOptions::CX_Basic:
652+
return cir::ComplexRangeKind::Basic;
653+
case LangOptions::CX_None:
654+
// The default value for ComplexRangeKind is Full is no option is selected
655+
return cir::ComplexRangeKind::Full;
656+
}
657+
}
658+
659+
mlir::Value ComplexExprEmitter::emitBinMul(const BinOpInfo &op) {
660+
assert(!cir::MissingFeatures::fastMathFlags());
661+
assert(!cir::MissingFeatures::cgFPOptionsRAII());
662+
cir::ComplexRangeKind rangeKind =
663+
getComplexRangeAttr(op.fpFeatures.getComplexRange());
664+
return builder.create<cir::ComplexMulOp>(op.loc, op.lhs, op.rhs, rangeKind);
665+
}
666+
639667
LValue CIRGenFunction::emitComplexAssignmentLValue(const BinaryOperator *e) {
640668
assert(e->getOpcode() == BO_Assign && "Expected assign op");
641669

clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp

Lines changed: 150 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
#include "clang/CIR/Dialect/Passes.h"
1616
#include "clang/CIR/MissingFeatures.h"
1717

18-
#include <iostream>
1918
#include <memory>
2019

2120
using namespace mlir;
@@ -28,21 +27,47 @@ struct LoweringPreparePass : public LoweringPrepareBase<LoweringPreparePass> {
2827

2928
void runOnOp(mlir::Operation *op);
3029
void lowerCastOp(cir::CastOp op);
30+
void lowerComplexMulOp(cir::ComplexMulOp op);
3131
void lowerUnaryOp(cir::UnaryOp op);
3232
void lowerArrayDtor(cir::ArrayDtor op);
3333
void lowerArrayCtor(cir::ArrayCtor op);
3434

35+
cir::FuncOp buildRuntimeFunction(
36+
mlir::OpBuilder &builder, llvm::StringRef name, mlir::Location loc,
37+
cir::FuncType type,
38+
cir::GlobalLinkageKind linkage = cir::GlobalLinkageKind::ExternalLinkage);
39+
3540
///
3641
/// AST related
3742
/// -----------
3843

3944
clang::ASTContext *astCtx;
4045

46+
/// Tracks current module.
47+
mlir::ModuleOp mlirModule;
48+
4149
void setASTContext(clang::ASTContext *c) { astCtx = c; }
4250
};
4351

4452
} // namespace
4553

54+
cir::FuncOp LoweringPreparePass::buildRuntimeFunction(
55+
mlir::OpBuilder &builder, llvm::StringRef name, mlir::Location loc,
56+
cir::FuncType type, cir::GlobalLinkageKind linkage) {
57+
cir::FuncOp f = dyn_cast_or_null<FuncOp>(SymbolTable::lookupNearestSymbolFrom(
58+
mlirModule, StringAttr::get(mlirModule->getContext(), name)));
59+
if (!f) {
60+
f = builder.create<cir::FuncOp>(loc, name, type);
61+
f.setLinkageAttr(
62+
cir::GlobalLinkageKindAttr::get(builder.getContext(), linkage));
63+
mlir::SymbolTable::setSymbolVisibility(
64+
f, mlir::SymbolTable::Visibility::Private);
65+
66+
assert(!cir::MissingFeatures::opFuncExtraAttrs());
67+
}
68+
return f;
69+
}
70+
4671
static mlir::Value lowerScalarToComplexCast(mlir::MLIRContext &ctx,
4772
cir::CastOp op) {
4873
cir::CIRBaseBuilderTy builder(ctx);
@@ -128,6 +153,124 @@ void LoweringPreparePass::lowerCastOp(cir::CastOp op) {
128153
}
129154
}
130155

156+
static mlir::Value buildComplexBinOpLibCall(
157+
LoweringPreparePass &pass, CIRBaseBuilderTy &builder,
158+
llvm::StringRef (*libFuncNameGetter)(llvm::APFloat::Semantics),
159+
mlir::Location loc, cir::ComplexType ty, mlir::Value lhsReal,
160+
mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag) {
161+
cir::FPTypeInterface elementTy =
162+
mlir::cast<cir::FPTypeInterface>(ty.getElementType());
163+
164+
llvm::StringRef libFuncName = libFuncNameGetter(
165+
llvm::APFloat::SemanticsToEnum(elementTy.getFloatSemantics()));
166+
llvm::SmallVector<mlir::Type, 4> libFuncInputTypes(4, elementTy);
167+
168+
cir::FuncType libFuncTy = cir::FuncType::get(libFuncInputTypes, ty);
169+
170+
// Insert a declaration for the runtime function to be used in Complex
171+
// multiplication and division when needed
172+
cir::FuncOp libFunc;
173+
{
174+
mlir::OpBuilder::InsertionGuard ipGuard{builder};
175+
builder.setInsertionPointToStart(pass.mlirModule.getBody());
176+
libFunc = pass.buildRuntimeFunction(builder, libFuncName, loc, libFuncTy);
177+
}
178+
179+
cir::CallOp call =
180+
builder.createCallOp(loc, libFunc, {lhsReal, lhsImag, rhsReal, rhsImag});
181+
return call.getResult();
182+
}
183+
184+
static llvm::StringRef
185+
getComplexMulLibCallName(llvm::APFloat::Semantics semantics) {
186+
switch (semantics) {
187+
case llvm::APFloat::S_IEEEhalf:
188+
return "__mulhc3";
189+
case llvm::APFloat::S_IEEEsingle:
190+
return "__mulsc3";
191+
case llvm::APFloat::S_IEEEdouble:
192+
return "__muldc3";
193+
case llvm::APFloat::S_PPCDoubleDouble:
194+
return "__multc3";
195+
case llvm::APFloat::S_x87DoubleExtended:
196+
return "__mulxc3";
197+
case llvm::APFloat::S_IEEEquad:
198+
return "__multc3";
199+
default:
200+
llvm_unreachable("unsupported floating point type");
201+
}
202+
}
203+
204+
static mlir::Value lowerComplexMul(LoweringPreparePass &pass,
205+
CIRBaseBuilderTy &builder,
206+
mlir::Location loc, cir::ComplexMulOp op,
207+
mlir::Value lhsReal, mlir::Value lhsImag,
208+
mlir::Value rhsReal, mlir::Value rhsImag) {
209+
// (a+bi) * (c+di) = (ac-bd) + (ad+bc)i
210+
mlir::Value resultRealLhs =
211+
builder.createBinop(loc, lhsReal, cir::BinOpKind::Mul, rhsReal);
212+
mlir::Value resultRealRhs =
213+
builder.createBinop(loc, lhsImag, cir::BinOpKind::Mul, rhsImag);
214+
mlir::Value resultImagLhs =
215+
builder.createBinop(loc, lhsReal, cir::BinOpKind::Mul, rhsImag);
216+
mlir::Value resultImagRhs =
217+
builder.createBinop(loc, lhsImag, cir::BinOpKind::Mul, rhsReal);
218+
mlir::Value resultReal = builder.createBinop(
219+
loc, resultRealLhs, cir::BinOpKind::Sub, resultRealRhs);
220+
mlir::Value resultImag = builder.createBinop(
221+
loc, resultImagLhs, cir::BinOpKind::Add, resultImagRhs);
222+
mlir::Value algebraicResult =
223+
builder.createComplexCreate(loc, resultReal, resultImag);
224+
225+
cir::ComplexType complexTy = op.getType();
226+
cir::ComplexRangeKind rangeKind = op.getRange();
227+
if (mlir::isa<cir::IntType>(complexTy.getElementType()) ||
228+
rangeKind == cir::ComplexRangeKind::Basic ||
229+
rangeKind == cir::ComplexRangeKind::Improved ||
230+
rangeKind == cir::ComplexRangeKind::Promoted)
231+
return algebraicResult;
232+
233+
assert(!cir::MissingFeatures::fastMathFlags());
234+
235+
// Check whether the real part and the imaginary part of the result are both
236+
// NaN. If so, emit a library call to compute the multiplication instead.
237+
// We check a value against NaN by comparing the value against itself.
238+
mlir::Value resultRealIsNaN = builder.createIsNaN(loc, resultReal);
239+
mlir::Value resultImagIsNaN = builder.createIsNaN(loc, resultImag);
240+
mlir::Value resultRealAndImagAreNaN =
241+
builder.createLogicalAnd(loc, resultRealIsNaN, resultImagIsNaN);
242+
243+
return builder
244+
.create<cir::TernaryOp>(
245+
loc, resultRealAndImagAreNaN,
246+
[&](mlir::OpBuilder &, mlir::Location) {
247+
mlir::Value libCallResult = buildComplexBinOpLibCall(
248+
pass, builder, &getComplexMulLibCallName, loc, complexTy,
249+
lhsReal, lhsImag, rhsReal, rhsImag);
250+
builder.createYield(loc, libCallResult);
251+
},
252+
[&](mlir::OpBuilder &, mlir::Location) {
253+
builder.createYield(loc, algebraicResult);
254+
})
255+
.getResult();
256+
}
257+
258+
void LoweringPreparePass::lowerComplexMulOp(cir::ComplexMulOp op) {
259+
cir::CIRBaseBuilderTy builder(getContext());
260+
builder.setInsertionPointAfter(op);
261+
mlir::Location loc = op.getLoc();
262+
mlir::TypedValue<cir::ComplexType> lhs = op.getLhs();
263+
mlir::TypedValue<cir::ComplexType> rhs = op.getRhs();
264+
mlir::Value lhsReal = builder.createComplexReal(loc, lhs);
265+
mlir::Value lhsImag = builder.createComplexImag(loc, lhs);
266+
mlir::Value rhsReal = builder.createComplexReal(loc, rhs);
267+
mlir::Value rhsImag = builder.createComplexImag(loc, rhs);
268+
mlir::Value loweredResult = lowerComplexMul(*this, builder, loc, op, lhsReal,
269+
lhsImag, rhsReal, rhsImag);
270+
op.replaceAllUsesWith(loweredResult);
271+
op.erase();
272+
}
273+
131274
void LoweringPreparePass::lowerUnaryOp(cir::UnaryOp op) {
132275
mlir::Type ty = op.getType();
133276
if (!mlir::isa<cir::ComplexType>(ty))
@@ -269,18 +412,22 @@ void LoweringPreparePass::runOnOp(mlir::Operation *op) {
269412
lowerArrayDtor(arrayDtor);
270413
else if (auto cast = mlir::dyn_cast<cir::CastOp>(op))
271414
lowerCastOp(cast);
415+
else if (auto complexMul = mlir::dyn_cast<cir::ComplexMulOp>(op))
416+
lowerComplexMulOp(complexMul);
272417
else if (auto unary = mlir::dyn_cast<cir::UnaryOp>(op))
273418
lowerUnaryOp(unary);
274419
}
275420

276421
void LoweringPreparePass::runOnOperation() {
277422
mlir::Operation *op = getOperation();
423+
if (isa<::mlir::ModuleOp>(op))
424+
mlirModule = cast<::mlir::ModuleOp>(op);
278425

279426
llvm::SmallVector<mlir::Operation *> opsToTransform;
280427

281428
op->walk([&](mlir::Operation *op) {
282-
if (mlir::isa<cir::ArrayCtor, cir::ArrayDtor, cir::CastOp, cir::UnaryOp>(
283-
op))
429+
if (mlir::isa<cir::ArrayCtor, cir::ArrayDtor, cir::CastOp,
430+
cir::ComplexMulOp, cir::UnaryOp>(op))
284431
opsToTransform.push_back(op);
285432
});
286433

0 commit comments

Comments
 (0)