88
99#include " PassDetail.h"
1010#include " clang/AST/ASTContext.h"
11+ #include " clang/Basic/TargetInfo.h"
1112#include " clang/CIR/Dialect/Builder/CIRBaseBuilder.h"
1213#include " clang/CIR/Dialect/IR/CIRDialect.h"
1314#include " clang/CIR/Dialect/IR/CIROpsEnums.h"
@@ -312,22 +313,125 @@ buildRangeReductionComplexDiv(CIRBaseBuilderTy &builder, mlir::Location loc,
312313 return ternary.getResult ();
313314}
314315
315- static mlir::Value lowerComplexDiv (LoweringPreparePass &pass,
316- CIRBaseBuilderTy &builder,
317- mlir::Location loc, cir::ComplexDivOp op,
318- mlir::Value lhsReal, mlir::Value lhsImag,
319- mlir::Value rhsReal, mlir::Value rhsImag) {
316+ static mlir::Type higherPrecisionElementTypeForComplexArithmetic (
317+ mlir::MLIRContext &context, clang::ASTContext &cc,
318+ CIRBaseBuilderTy &builder, mlir::Type elementType) {
319+
320+ auto getHigherPrecisionFPType = [&context](mlir::Type type) -> mlir::Type {
321+ if (mlir::isa<cir::FP16Type>(type))
322+ return cir::SingleType::get (&context);
323+
324+ if (mlir::isa<cir::SingleType>(type) || mlir::isa<cir::BF16Type>(type))
325+ return cir::DoubleType::get (&context);
326+
327+ if (mlir::isa<cir::DoubleType>(type))
328+ return cir::LongDoubleType::get (&context, type);
329+
330+ return type;
331+ };
332+
333+ auto getFloatTypeSemantics =
334+ [&cc](mlir::Type type) -> const llvm::fltSemantics & {
335+ const clang::TargetInfo &info = cc.getTargetInfo ();
336+ if (mlir::isa<cir::FP16Type>(type))
337+ return info.getHalfFormat ();
338+
339+ if (mlir::isa<cir::BF16Type>(type))
340+ return info.getBFloat16Format ();
341+
342+ if (mlir::isa<cir::SingleType>(type))
343+ return info.getFloatFormat ();
344+
345+ if (mlir::isa<cir::DoubleType>(type))
346+ return info.getDoubleFormat ();
347+
348+ if (mlir::isa<cir::LongDoubleType>(type)) {
349+ if (cc.getLangOpts ().OpenMP && cc.getLangOpts ().OpenMPIsTargetDevice )
350+ llvm_unreachable (" NYI Float type semantics with OpenMP" );
351+ return info.getLongDoubleFormat ();
352+ }
353+
354+ if (mlir::isa<cir::FP128Type>(type)) {
355+ if (cc.getLangOpts ().OpenMP && cc.getLangOpts ().OpenMPIsTargetDevice )
356+ llvm_unreachable (" NYI Float type semantics with OpenMP" );
357+ return info.getFloat128Format ();
358+ }
359+
360+ assert (false && " Unsupported float type semantics" );
361+ };
362+
363+ const mlir::Type higherElementType = getHigherPrecisionFPType (elementType);
364+ const llvm::fltSemantics &elementTypeSemantics =
365+ getFloatTypeSemantics (elementType);
366+ const llvm::fltSemantics &higherElementTypeSemantics =
367+ getFloatTypeSemantics (higherElementType);
368+
369+ // Check that the promoted type can handle the intermediate values without
370+ // overflowing. This can be interpreted as:
371+ // (SmallerType.LargestFiniteVal * SmallerType.LargestFiniteVal) * 2 <=
372+ // LargerType.LargestFiniteVal.
373+ // In terms of exponent it gives this formula:
374+ // (SmallerType.LargestFiniteVal * SmallerType.LargestFiniteVal
375+ // doubles the exponent of SmallerType.LargestFiniteVal)
376+ if (llvm::APFloat::semanticsMaxExponent (elementTypeSemantics) * 2 + 1 <=
377+ llvm::APFloat::semanticsMaxExponent (higherElementTypeSemantics)) {
378+ return higherElementType;
379+ }
380+
381+ // The intermediate values can't be represented in the promoted type
382+ // without overflowing.
383+ return {};
384+ }
385+
386+ static mlir::Value
387+ lowerComplexDiv (LoweringPreparePass &pass, CIRBaseBuilderTy &builder,
388+ mlir::Location loc, cir::ComplexDivOp op, mlir::Value lhsReal,
389+ mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag,
390+ mlir::MLIRContext &mlirCx, clang::ASTContext &cc) {
320391 cir::ComplexType complexTy = op.getType ();
321392 if (mlir::isa<cir::FPTypeInterface>(complexTy.getElementType ())) {
322393 cir::ComplexRangeKind range = op.getRange ();
323- if (range == cir::ComplexRangeKind::Improved ||
324- (range == cir::ComplexRangeKind::Promoted && !op.getPromoted ()))
394+ if (range == cir::ComplexRangeKind::Improved)
325395 return buildRangeReductionComplexDiv (builder, loc, lhsReal, lhsImag,
326396 rhsReal, rhsImag);
397+
327398 if (range == cir::ComplexRangeKind::Full)
328399 return buildComplexBinOpLibCall (pass, builder, &getComplexDivLibCallName,
329400 loc, complexTy, lhsReal, lhsImag, rhsReal,
330401 rhsImag);
402+
403+ if (range == cir::ComplexRangeKind::Promoted) {
404+ mlir::Type originalElementType = complexTy.getElementType ();
405+ mlir::Type higherPrecisionElementType =
406+ higherPrecisionElementTypeForComplexArithmetic (mlirCx, cc, builder,
407+ originalElementType);
408+
409+ if (!higherPrecisionElementType)
410+ return buildRangeReductionComplexDiv (builder, loc, lhsReal, lhsImag,
411+ rhsReal, rhsImag);
412+
413+ cir::CastKind floatingCastKind = cir::CastKind::floating;
414+ lhsReal = builder.createCast (floatingCastKind, lhsReal,
415+ higherPrecisionElementType);
416+ lhsImag = builder.createCast (floatingCastKind, lhsImag,
417+ higherPrecisionElementType);
418+ rhsReal = builder.createCast (floatingCastKind, rhsReal,
419+ higherPrecisionElementType);
420+ rhsImag = builder.createCast (floatingCastKind, rhsImag,
421+ higherPrecisionElementType);
422+
423+ mlir::Value algebraicResult = buildAlgebraicComplexDiv (
424+ builder, loc, lhsReal, lhsImag, rhsReal, rhsImag);
425+
426+ mlir::Value resultReal = builder.createComplexReal (loc, algebraicResult);
427+ mlir::Value resultImag = builder.createComplexImag (loc, algebraicResult);
428+
429+ mlir::Value finalReal =
430+ builder.createCast (floatingCastKind, resultReal, originalElementType);
431+ mlir::Value finalImag =
432+ builder.createCast (floatingCastKind, resultImag, originalElementType);
433+ return builder.createComplexCreate (loc, finalReal, finalImag);
434+ }
331435 }
332436
333437 return buildAlgebraicComplexDiv (builder, loc, lhsReal, lhsImag, rhsReal,
@@ -345,8 +449,9 @@ void LoweringPreparePass::lowerComplexDivOp(cir::ComplexDivOp op) {
345449 mlir::Value rhsReal = builder.createComplexReal (loc, rhs);
346450 mlir::Value rhsImag = builder.createComplexImag (loc, rhs);
347451
348- mlir::Value loweredResult = lowerComplexDiv (*this , builder, loc, op, lhsReal,
349- lhsImag, rhsReal, rhsImag);
452+ mlir::Value loweredResult =
453+ lowerComplexDiv (*this , builder, loc, op, lhsReal, lhsImag, rhsReal,
454+ rhsImag, getContext (), *astCtx);
350455 op.replaceAllUsesWith (loweredResult);
351456 op.erase ();
352457}
0 commit comments