88
99#include " PassDetail.h"
1010#include " clang/AST/ASTContext.h"
11- #include " clang/AST/CharUnits.h"
1211#include " clang/CIR/Dialect/Builder/CIRBaseBuilder.h"
1312#include " clang/CIR/Dialect/IR/CIRDialect.h"
1413#include " clang/CIR/Dialect/IR/CIROpsEnums.h"
@@ -27,6 +26,7 @@ struct LoweringPreparePass : public LoweringPrepareBase<LoweringPreparePass> {
2726
2827 void runOnOp (mlir::Operation *op);
2928 void lowerCastOp (cir::CastOp op);
29+ void lowerComplexDivOp (cir::ComplexDivOp op);
3030 void lowerComplexMulOp (cir::ComplexMulOp op);
3131 void lowerUnaryOp (cir::UnaryOp op);
3232 void lowerArrayDtor (cir::ArrayDtor op);
@@ -181,6 +181,176 @@ static mlir::Value buildComplexBinOpLibCall(
181181 return call.getResult ();
182182}
183183
184+ static llvm::StringRef
185+ getComplexDivLibCallName (llvm::APFloat::Semantics semantics) {
186+ switch (semantics) {
187+ case llvm::APFloat::S_IEEEhalf:
188+ return " __divhc3" ;
189+ case llvm::APFloat::S_IEEEsingle:
190+ return " __divsc3" ;
191+ case llvm::APFloat::S_IEEEdouble:
192+ return " __divdc3" ;
193+ case llvm::APFloat::S_PPCDoubleDouble:
194+ return " __divtc3" ;
195+ case llvm::APFloat::S_x87DoubleExtended:
196+ return " __divxc3" ;
197+ case llvm::APFloat::S_IEEEquad:
198+ return " __divtc3" ;
199+ default :
200+ llvm_unreachable (" unsupported floating point type" );
201+ }
202+ }
203+
204+ static mlir::Value
205+ buildAlgebraicComplexDiv (CIRBaseBuilderTy &builder, mlir::Location loc,
206+ mlir::Value lhsReal, mlir::Value lhsImag,
207+ mlir::Value rhsReal, mlir::Value rhsImag) {
208+ // (a+bi) / (c+di) = ((ac+bd)/(cc+dd)) + ((bc-ad)/(cc+dd))i
209+ mlir::Value &a = lhsReal;
210+ mlir::Value &b = lhsImag;
211+ mlir::Value &c = rhsReal;
212+ mlir::Value &d = rhsImag;
213+
214+ mlir::Value ac = builder.createBinop (loc, a, cir::BinOpKind::Mul, c); // a*c
215+ mlir::Value bd = builder.createBinop (loc, b, cir::BinOpKind::Mul, d); // b*d
216+ mlir::Value cc = builder.createBinop (loc, c, cir::BinOpKind::Mul, c); // c*c
217+ mlir::Value dd = builder.createBinop (loc, d, cir::BinOpKind::Mul, d); // d*d
218+ mlir::Value acbd =
219+ builder.createBinop (loc, ac, cir::BinOpKind::Add, bd); // ac+bd
220+ mlir::Value ccdd =
221+ builder.createBinop (loc, cc, cir::BinOpKind::Add, dd); // cc+dd
222+ mlir::Value resultReal =
223+ builder.createBinop (loc, acbd, cir::BinOpKind::Div, ccdd);
224+
225+ mlir::Value bc = builder.createBinop (loc, b, cir::BinOpKind::Mul, c); // b*c
226+ mlir::Value ad = builder.createBinop (loc, a, cir::BinOpKind::Mul, d); // a*d
227+ mlir::Value bcad =
228+ builder.createBinop (loc, bc, cir::BinOpKind::Sub, ad); // bc-ad
229+ mlir::Value resultImag =
230+ builder.createBinop (loc, bcad, cir::BinOpKind::Div, ccdd);
231+ return builder.createComplexCreate (loc, resultReal, resultImag);
232+ }
233+
234+ static mlir::Value
235+ buildRangeReductionComplexDiv (CIRBaseBuilderTy &builder, mlir::Location loc,
236+ mlir::Value lhsReal, mlir::Value lhsImag,
237+ mlir::Value rhsReal, mlir::Value rhsImag) {
238+ // Implements Smith's algorithm for complex division.
239+ // SMITH, R. L. Algorithm 116: Complex division. Commun. ACM 5, 8 (1962).
240+
241+ // Let:
242+ // - lhs := a+bi
243+ // - rhs := c+di
244+ // - result := lhs / rhs = e+fi
245+ //
246+ // The algorithm pseudocode looks like follows:
247+ // if fabs(c) >= fabs(d):
248+ // r := d / c
249+ // tmp := c + r*d
250+ // e = (a + b*r) / tmp
251+ // f = (b - a*r) / tmp
252+ // else:
253+ // r := c / d
254+ // tmp := d + r*c
255+ // e = (a*r + b) / tmp
256+ // f = (b*r - a) / tmp
257+
258+ mlir::Value &a = lhsReal;
259+ mlir::Value &b = lhsImag;
260+ mlir::Value &c = rhsReal;
261+ mlir::Value &d = rhsImag;
262+
263+ auto trueBranchBuilder = [&](mlir::OpBuilder &, mlir::Location) {
264+ mlir::Value r = builder.createBinop (loc, d, cir::BinOpKind::Div,
265+ c); // r := d / c
266+ mlir::Value rd = builder.createBinop (loc, r, cir::BinOpKind::Mul, d); // r*d
267+ mlir::Value tmp = builder.createBinop (loc, c, cir::BinOpKind::Add,
268+ rd); // tmp := c + r*d
269+
270+ mlir::Value br = builder.createBinop (loc, b, cir::BinOpKind::Mul, r); // b*r
271+ mlir::Value abr =
272+ builder.createBinop (loc, a, cir::BinOpKind::Add, br); // a + b*r
273+ mlir::Value e = builder.createBinop (loc, abr, cir::BinOpKind::Div, tmp);
274+
275+ mlir::Value ar = builder.createBinop (loc, a, cir::BinOpKind::Mul, r); // a*r
276+ mlir::Value bar =
277+ builder.createBinop (loc, b, cir::BinOpKind::Sub, ar); // b - a*r
278+ mlir::Value f = builder.createBinop (loc, bar, cir::BinOpKind::Div, tmp);
279+
280+ mlir::Value result = builder.createComplexCreate (loc, e, f);
281+ builder.createYield (loc, result);
282+ };
283+
284+ auto falseBranchBuilder = [&](mlir::OpBuilder &, mlir::Location) {
285+ mlir::Value r = builder.createBinop (loc, c, cir::BinOpKind::Div,
286+ d); // r := c / d
287+ mlir::Value rc = builder.createBinop (loc, r, cir::BinOpKind::Mul, c); // r*c
288+ mlir::Value tmp = builder.createBinop (loc, d, cir::BinOpKind::Add,
289+ rc); // tmp := d + r*c
290+
291+ mlir::Value ar = builder.createBinop (loc, a, cir::BinOpKind::Mul, r); // a*r
292+ mlir::Value arb =
293+ builder.createBinop (loc, ar, cir::BinOpKind::Add, b); // a*r + b
294+ mlir::Value e = builder.createBinop (loc, arb, cir::BinOpKind::Div, tmp);
295+
296+ mlir::Value br = builder.createBinop (loc, b, cir::BinOpKind::Mul, r); // b*r
297+ mlir::Value bra =
298+ builder.createBinop (loc, br, cir::BinOpKind::Sub, a); // b*r - a
299+ mlir::Value f = builder.createBinop (loc, bra, cir::BinOpKind::Div, tmp);
300+
301+ mlir::Value result = builder.createComplexCreate (loc, e, f);
302+ builder.createYield (loc, result);
303+ };
304+
305+ auto cFabs = builder.create <cir::FAbsOp>(loc, c);
306+ auto dFabs = builder.create <cir::FAbsOp>(loc, d);
307+ cir::CmpOp cmpResult =
308+ builder.createCompare (loc, cir::CmpOpKind::ge, cFabs, dFabs);
309+ auto ternary = builder.create <cir::TernaryOp>(
310+ loc, cmpResult, trueBranchBuilder, falseBranchBuilder);
311+
312+ return ternary.getResult ();
313+ }
314+
315+ static mlir::Value lowerComplexDiv (LoweringPreparePass &pass,
316+ CIRBaseBuilderTy &builder,
317+ mlir::Location loc, cir::ComplexDivOp op,
318+ mlir::Value lhsReal, mlir::Value lhsImag,
319+ mlir::Value rhsReal, mlir::Value rhsImag) {
320+ cir::ComplexType complexTy = op.getType ();
321+ if (mlir::isa<cir::FPTypeInterface>(complexTy.getElementType ())) {
322+ cir::ComplexRangeKind range = op.getRange ();
323+ if (range == cir::ComplexRangeKind::Improved ||
324+ (range == cir::ComplexRangeKind::Promoted && !op.getPromoted ()))
325+ return buildRangeReductionComplexDiv (builder, loc, lhsReal, lhsImag,
326+ rhsReal, rhsImag);
327+ if (range == cir::ComplexRangeKind::Full)
328+ return buildComplexBinOpLibCall (pass, builder, &getComplexDivLibCallName,
329+ loc, complexTy, lhsReal, lhsImag, rhsReal,
330+ rhsImag);
331+ }
332+
333+ return buildAlgebraicComplexDiv (builder, loc, lhsReal, lhsImag, rhsReal,
334+ rhsImag);
335+ }
336+
337+ void LoweringPreparePass::lowerComplexDivOp (cir::ComplexDivOp op) {
338+ cir::CIRBaseBuilderTy builder (getContext ());
339+ builder.setInsertionPointAfter (op);
340+ mlir::Location loc = op.getLoc ();
341+ mlir::TypedValue<cir::ComplexType> lhs = op.getLhs ();
342+ mlir::TypedValue<cir::ComplexType> rhs = op.getRhs ();
343+ mlir::Value lhsReal = builder.createComplexReal (loc, lhs);
344+ mlir::Value lhsImag = builder.createComplexImag (loc, lhs);
345+ mlir::Value rhsReal = builder.createComplexReal (loc, rhs);
346+ mlir::Value rhsImag = builder.createComplexImag (loc, rhs);
347+
348+ mlir::Value loweredResult = lowerComplexDiv (*this , builder, loc, op, lhsReal,
349+ lhsImag, rhsReal, rhsImag);
350+ op.replaceAllUsesWith (loweredResult);
351+ op.erase ();
352+ }
353+
184354static llvm::StringRef
185355getComplexMulLibCallName (llvm::APFloat::Semantics semantics) {
186356 switch (semantics) {
@@ -412,6 +582,8 @@ void LoweringPreparePass::runOnOp(mlir::Operation *op) {
412582 lowerArrayDtor (arrayDtor);
413583 else if (auto cast = mlir::dyn_cast<cir::CastOp>(op))
414584 lowerCastOp (cast);
585+ else if (auto complexDiv = mlir::dyn_cast<cir::ComplexDivOp>(op))
586+ lowerComplexDivOp (complexDiv);
415587 else if (auto complexMul = mlir::dyn_cast<cir::ComplexMulOp>(op))
416588 lowerComplexMulOp (complexMul);
417589 else if (auto unary = mlir::dyn_cast<cir::UnaryOp>(op))
@@ -427,7 +599,7 @@ void LoweringPreparePass::runOnOperation() {
427599
428600 op->walk ([&](mlir::Operation *op) {
429601 if (mlir::isa<cir::ArrayCtor, cir::ArrayDtor, cir::CastOp,
430- cir::ComplexMulOp, cir::UnaryOp>(op))
602+ cir::ComplexMulOp, cir::ComplexDivOp, cir:: UnaryOp>(op))
431603 opsToTransform.push_back (op);
432604 });
433605
0 commit comments