1515#include " clang/CIR/Dialect/Passes.h"
1616#include " clang/CIR/MissingFeatures.h"
1717
18- #include < iostream>
1918#include < memory>
2019
2120using namespace mlir ;
@@ -28,21 +27,47 @@ struct LoweringPreparePass : public LoweringPrepareBase<LoweringPreparePass> {
2827
2928 void runOnOp (mlir::Operation *op);
3029 void lowerCastOp (cir::CastOp op);
30+ void lowerComplexMulOp (cir::ComplexMulOp op);
3131 void lowerUnaryOp (cir::UnaryOp op);
3232 void lowerArrayDtor (cir::ArrayDtor op);
3333 void lowerArrayCtor (cir::ArrayCtor op);
3434
35+ cir::FuncOp buildRuntimeFunction (
36+ mlir::OpBuilder &builder, llvm::StringRef name, mlir::Location loc,
37+ cir::FuncType type,
38+ cir::GlobalLinkageKind linkage = cir::GlobalLinkageKind::ExternalLinkage);
39+
3540 // /
3641 // / AST related
3742 // / -----------
3843
3944 clang::ASTContext *astCtx;
4045
46+ // / Tracks current module.
47+ mlir::ModuleOp mlirModule;
48+
4149 void setASTContext (clang::ASTContext *c) { astCtx = c; }
4250};
4351
4452} // namespace
4553
54+ cir::FuncOp LoweringPreparePass::buildRuntimeFunction (
55+ mlir::OpBuilder &builder, llvm::StringRef name, mlir::Location loc,
56+ cir::FuncType type, cir::GlobalLinkageKind linkage) {
57+ cir::FuncOp f = dyn_cast_or_null<FuncOp>(SymbolTable::lookupNearestSymbolFrom (
58+ mlirModule, StringAttr::get (mlirModule->getContext (), name)));
59+ if (!f) {
60+ f = builder.create <cir::FuncOp>(loc, name, type);
61+ f.setLinkageAttr (
62+ cir::GlobalLinkageKindAttr::get (builder.getContext (), linkage));
63+ mlir::SymbolTable::setSymbolVisibility (
64+ f, mlir::SymbolTable::Visibility::Private);
65+
66+ assert (!cir::MissingFeatures::opFuncExtraAttrs ());
67+ }
68+ return f;
69+ }
70+
4671static mlir::Value lowerScalarToComplexCast (mlir::MLIRContext &ctx,
4772 cir::CastOp op) {
4873 cir::CIRBaseBuilderTy builder (ctx);
@@ -128,6 +153,124 @@ void LoweringPreparePass::lowerCastOp(cir::CastOp op) {
128153 }
129154}
130155
156+ static mlir::Value buildComplexBinOpLibCall (
157+ LoweringPreparePass &pass, CIRBaseBuilderTy &builder,
158+ llvm::StringRef (*libFuncNameGetter)(llvm::APFloat::Semantics),
159+ mlir::Location loc, cir::ComplexType ty, mlir::Value lhsReal,
160+ mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag) {
161+ cir::FPTypeInterface elementTy =
162+ mlir::cast<cir::FPTypeInterface>(ty.getElementType ());
163+
164+ llvm::StringRef libFuncName = libFuncNameGetter (
165+ llvm::APFloat::SemanticsToEnum (elementTy.getFloatSemantics ()));
166+ llvm::SmallVector<mlir::Type, 4 > libFuncInputTypes (4 , elementTy);
167+
168+ cir::FuncType libFuncTy = cir::FuncType::get (libFuncInputTypes, ty);
169+
170+ // Insert a declaration for the runtime function to be used in Complex
171+ // multiplication and division when needed
172+ cir::FuncOp libFunc;
173+ {
174+ mlir::OpBuilder::InsertionGuard ipGuard{builder};
175+ builder.setInsertionPointToStart (pass.mlirModule .getBody ());
176+ libFunc = pass.buildRuntimeFunction (builder, libFuncName, loc, libFuncTy);
177+ }
178+
179+ cir::CallOp call =
180+ builder.createCallOp (loc, libFunc, {lhsReal, lhsImag, rhsReal, rhsImag});
181+ return call.getResult ();
182+ }
183+
184+ static llvm::StringRef
185+ getComplexMulLibCallName (llvm::APFloat::Semantics semantics) {
186+ switch (semantics) {
187+ case llvm::APFloat::S_IEEEhalf:
188+ return " __mulhc3" ;
189+ case llvm::APFloat::S_IEEEsingle:
190+ return " __mulsc3" ;
191+ case llvm::APFloat::S_IEEEdouble:
192+ return " __muldc3" ;
193+ case llvm::APFloat::S_PPCDoubleDouble:
194+ return " __multc3" ;
195+ case llvm::APFloat::S_x87DoubleExtended:
196+ return " __mulxc3" ;
197+ case llvm::APFloat::S_IEEEquad:
198+ return " __multc3" ;
199+ default :
200+ llvm_unreachable (" unsupported floating point type" );
201+ }
202+ }
203+
204+ static mlir::Value lowerComplexMul (LoweringPreparePass &pass,
205+ CIRBaseBuilderTy &builder,
206+ mlir::Location loc, cir::ComplexMulOp op,
207+ mlir::Value lhsReal, mlir::Value lhsImag,
208+ mlir::Value rhsReal, mlir::Value rhsImag) {
209+ // (a+bi) * (c+di) = (ac-bd) + (ad+bc)i
210+ mlir::Value resultRealLhs =
211+ builder.createBinop (loc, lhsReal, cir::BinOpKind::Mul, rhsReal);
212+ mlir::Value resultRealRhs =
213+ builder.createBinop (loc, lhsImag, cir::BinOpKind::Mul, rhsImag);
214+ mlir::Value resultImagLhs =
215+ builder.createBinop (loc, lhsReal, cir::BinOpKind::Mul, rhsImag);
216+ mlir::Value resultImagRhs =
217+ builder.createBinop (loc, lhsImag, cir::BinOpKind::Mul, rhsReal);
218+ mlir::Value resultReal = builder.createBinop (
219+ loc, resultRealLhs, cir::BinOpKind::Sub, resultRealRhs);
220+ mlir::Value resultImag = builder.createBinop (
221+ loc, resultImagLhs, cir::BinOpKind::Add, resultImagRhs);
222+ mlir::Value algebraicResult =
223+ builder.createComplexCreate (loc, resultReal, resultImag);
224+
225+ cir::ComplexType complexTy = op.getType ();
226+ cir::ComplexRangeKind rangeKind = op.getRange ();
227+ if (mlir::isa<cir::IntType>(complexTy.getElementType ()) ||
228+ rangeKind == cir::ComplexRangeKind::Basic ||
229+ rangeKind == cir::ComplexRangeKind::Improved ||
230+ rangeKind == cir::ComplexRangeKind::Promoted)
231+ return algebraicResult;
232+
233+ assert (!cir::MissingFeatures::fastMathFlags ());
234+
235+ // Check whether the real part and the imaginary part of the result are both
236+ // NaN. If so, emit a library call to compute the multiplication instead.
237+ // We check a value against NaN by comparing the value against itself.
238+ mlir::Value resultRealIsNaN = builder.createIsNaN (loc, resultReal);
239+ mlir::Value resultImagIsNaN = builder.createIsNaN (loc, resultImag);
240+ mlir::Value resultRealAndImagAreNaN =
241+ builder.createLogicalAnd (loc, resultRealIsNaN, resultImagIsNaN);
242+
243+ return builder
244+ .create <cir::TernaryOp>(
245+ loc, resultRealAndImagAreNaN,
246+ [&](mlir::OpBuilder &, mlir::Location) {
247+ mlir::Value libCallResult = buildComplexBinOpLibCall (
248+ pass, builder, &getComplexMulLibCallName, loc, complexTy,
249+ lhsReal, lhsImag, rhsReal, rhsImag);
250+ builder.createYield (loc, libCallResult);
251+ },
252+ [&](mlir::OpBuilder &, mlir::Location) {
253+ builder.createYield (loc, algebraicResult);
254+ })
255+ .getResult ();
256+ }
257+
258+ void LoweringPreparePass::lowerComplexMulOp (cir::ComplexMulOp op) {
259+ cir::CIRBaseBuilderTy builder (getContext ());
260+ builder.setInsertionPointAfter (op);
261+ mlir::Location loc = op.getLoc ();
262+ mlir::TypedValue<cir::ComplexType> lhs = op.getLhs ();
263+ mlir::TypedValue<cir::ComplexType> rhs = op.getRhs ();
264+ mlir::Value lhsReal = builder.createComplexReal (loc, lhs);
265+ mlir::Value lhsImag = builder.createComplexImag (loc, lhs);
266+ mlir::Value rhsReal = builder.createComplexReal (loc, rhs);
267+ mlir::Value rhsImag = builder.createComplexImag (loc, rhs);
268+ mlir::Value loweredResult = lowerComplexMul (*this , builder, loc, op, lhsReal,
269+ lhsImag, rhsReal, rhsImag);
270+ op.replaceAllUsesWith (loweredResult);
271+ op.erase ();
272+ }
273+
131274void LoweringPreparePass::lowerUnaryOp (cir::UnaryOp op) {
132275 mlir::Type ty = op.getType ();
133276 if (!mlir::isa<cir::ComplexType>(ty))
@@ -269,18 +412,22 @@ void LoweringPreparePass::runOnOp(mlir::Operation *op) {
269412 lowerArrayDtor (arrayDtor);
270413 else if (auto cast = mlir::dyn_cast<cir::CastOp>(op))
271414 lowerCastOp (cast);
415+ else if (auto complexMul = mlir::dyn_cast<cir::ComplexMulOp>(op))
416+ lowerComplexMulOp (complexMul);
272417 else if (auto unary = mlir::dyn_cast<cir::UnaryOp>(op))
273418 lowerUnaryOp (unary);
274419}
275420
276421void LoweringPreparePass::runOnOperation () {
277422 mlir::Operation *op = getOperation ();
423+ if (isa<::mlir::ModuleOp>(op))
424+ mlirModule = cast<::mlir::ModuleOp>(op);
278425
279426 llvm::SmallVector<mlir::Operation *> opsToTransform;
280427
281428 op->walk ([&](mlir::Operation *op) {
282- if (mlir::isa<cir::ArrayCtor, cir::ArrayDtor, cir::CastOp, cir::UnaryOp>(
283- op))
429+ if (mlir::isa<cir::ArrayCtor, cir::ArrayDtor, cir::CastOp,
430+ cir::ComplexMulOp, cir::UnaryOp>( op))
284431 opsToTransform.push_back (op);
285432 });
286433
0 commit comments