Skip to content

Commit 6879b70

Browse files
committed
Address code review comments
1 parent 6613d4c commit 6879b70

File tree

4 files changed

+145
-77
lines changed

4 files changed

+145
-77
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2985,8 +2985,8 @@ def CIR_ComplexMulOp : CIR_Op<"complex.mul", [
29852985
The `cir.complex.mul` operation takes two complex numbers and returns
29862986
their product.
29872987

2988-
Range is used to select the implementation used when the operation
2989-
is lowered to the LLVM dialect. For multiplication, 'improved',
2988+
The `range` attribute is used to select the algorithm used when the
2989+
operation is lowered to the LLVM dialect. For multiplication, 'improved',
29902990
'promoted', and 'basic' are all handled equivalently, producing the
29912991
algebraic formula with no special handling for NaN values. If 'full' is
29922992
used, a runtime-library function is called if one of the intermediate
@@ -3019,15 +3019,19 @@ def CIR_ComplexDivOp : CIR_Op<"complex.div", [
30193019
let summary = "Complex division";
30203020
let description = [{
30213021
The `cir.complex.div` operation takes two complex numbers and returns
3022-
their division.
3023-
3024-
Range is used to select the implementation used when the operation
3025-
is lowered to the LLVM dialect. For division, 'improved' and
3026-
'promoted' are all handled equivalently, producing the
3027-
Smith's algorithms for Complex division. If 'full' is used,
3028-
a runtime-library function is called if one of the intermediate
3029-
calculations produced a NaN value, and for 'basic' algebraic formula with
3030-
no special handling for NaN value will be used.
3022+
their quotient.
3023+
3024+
The `range` attribute is used to select the algorithm used when
3025+
the operation is lowered to the LLVM dialect. For division, 'improved'
3026+
produces Smith's algorithm for complex division with no special
3027+
handling for NaN values. If 'promoted' is used, the values are promoted
3028+
to a higher precision type, if possible, and the calculation is performed
3029+
using the algebraic formula. We only fall back on Smith's algorithm when
3030+
the target does not support a higher precision type. Also, this only
3031+
applies to floating-point types with no special handling for NaN values.
3032+
If 'full' is used, a runtime-library function is called if one of the
3033+
intermediate calculations produced a NaN value. If 'basic' is used, the
3034+
algebraic formula with no special handling for NaN values is used.
30313035

30323036
Example:
30333037

@@ -3040,8 +3044,7 @@ def CIR_ComplexDivOp : CIR_Op<"complex.div", [
30403044
let arguments = (ins
30413045
CIR_ComplexType:$lhs,
30423046
CIR_ComplexType:$rhs,
3043-
CIR_ComplexRangeKind:$range,
3044-
UnitAttr:$promoted
3047+
CIR_ComplexRangeKind:$range
30453048
);
30463049

30473050
let results = (outs CIR_ComplexType:$result);

clang/lib/CIR/CodeGen/CIRGenExprComplex.cpp

Lines changed: 1 addition & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ namespace {
1010
class ComplexExprEmitter : public StmtVisitor<ComplexExprEmitter, mlir::Value> {
1111
CIRGenFunction &cgf;
1212
CIRGenBuilderTy &builder;
13-
bool fpHasBeenPromoted = false;
1413

1514
public:
1615
explicit ComplexExprEmitter(CIRGenFunction &cgf)
@@ -131,43 +130,9 @@ class ComplexExprEmitter : public StmtVisitor<ComplexExprEmitter, mlir::Value> {
131130
mlir::Value emitBinMul(const BinOpInfo &op);
132131
mlir::Value emitBinDiv(const BinOpInfo &op);
133132

134-
QualType higherPrecisionTypeForComplexArithmetic(QualType elementType,
135-
bool isDivOpCode) {
136-
ASTContext &astContext = cgf.getContext();
137-
const QualType higherElementType =
138-
astContext.GetHigherPrecisionFPType(elementType);
139-
const llvm::fltSemantics &elementTypeSemantics =
140-
astContext.getFloatTypeSemantics(elementType);
141-
const llvm::fltSemantics &higherElementTypeSemantics =
142-
astContext.getFloatTypeSemantics(higherElementType);
143-
144-
// Check that the promoted type can handle the intermediate values without
145-
// overflowing. This can be interpreted as:
146-
// (SmallerType.LargestFiniteVal * SmallerType.LargestFiniteVal) * 2 <=
147-
// LargerType.LargestFiniteVal.
148-
// In terms of exponent it gives this formula:
149-
// (SmallerType.LargestFiniteVal * SmallerType.LargestFiniteVal
150-
// doubles the exponent of SmallerType.LargestFiniteVal)
151-
if (llvm::APFloat::semanticsMaxExponent(elementTypeSemantics) * 2 + 1 <=
152-
llvm::APFloat::semanticsMaxExponent(higherElementTypeSemantics)) {
153-
fpHasBeenPromoted = true;
154-
return astContext.getComplexType(higherElementType);
155-
}
156-
157-
// The intermediate values can't be represented in the promoted type
158-
// without overflowing.
159-
return QualType();
160-
}
161-
162133
QualType getPromotionType(QualType ty, bool isDivOpCode = false) {
163134
if (auto *complexTy = ty->getAs<ComplexType>()) {
164135
QualType elementTy = complexTy->getElementType();
165-
if (isDivOpCode && elementTy->isFloatingType() &&
166-
cgf.getLangOpts().getComplexRange() ==
167-
LangOptions::ComplexRangeKind::CX_Promoted) {
168-
return higherPrecisionTypeForComplexArithmetic(elementTy, isDivOpCode);
169-
}
170-
171136
if (elementTy.UseExcessPrecision(cgf.getContext()))
172137
return cgf.getContext().getComplexType(cgf.getContext().FloatTy);
173138
}
@@ -896,8 +861,7 @@ mlir::Value ComplexExprEmitter::emitBinDiv(const BinOpInfo &op) {
896861
mlir::isa<cir::ComplexType>(op.rhs.getType())) {
897862
cir::ComplexRangeKind rangeKind =
898863
getComplexRangeAttr(op.fpFeatures.getComplexRange());
899-
return builder.create<cir::ComplexDivOp>(op.loc, op.lhs, op.rhs, rangeKind,
900-
fpHasBeenPromoted);
864+
return builder.create<cir::ComplexDivOp>(op.loc, op.lhs, op.rhs, rangeKind);
901865
}
902866

903867
cgf.cgm.errorNYI("ComplexExprEmitter::emitBinMu between Complex & Scalar");

clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp

Lines changed: 114 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include "PassDetail.h"
1010
#include "clang/AST/ASTContext.h"
11+
#include "clang/Basic/TargetInfo.h"
1112
#include "clang/CIR/Dialect/Builder/CIRBaseBuilder.h"
1213
#include "clang/CIR/Dialect/IR/CIRDialect.h"
1314
#include "clang/CIR/Dialect/IR/CIROpsEnums.h"
@@ -312,22 +313,125 @@ buildRangeReductionComplexDiv(CIRBaseBuilderTy &builder, mlir::Location loc,
312313
return ternary.getResult();
313314
}
314315

315-
static mlir::Value lowerComplexDiv(LoweringPreparePass &pass,
316-
CIRBaseBuilderTy &builder,
317-
mlir::Location loc, cir::ComplexDivOp op,
318-
mlir::Value lhsReal, mlir::Value lhsImag,
319-
mlir::Value rhsReal, mlir::Value rhsImag) {
316+
/// Selects the element type in which a 'promoted' complex division should be
/// evaluated.
///
/// The candidate is the next wider CIR floating-point type (fp16 -> float,
/// float/bf16 -> double, double -> long double). It is accepted only when its
/// maximum exponent is at least `2 * maxExponent(elementType) + 1`, so that
/// the intermediate products and their sum cannot overflow.
///
/// \param context     MLIR context used to create the wider type.
/// \param cc          AST context; supplies the target's float formats and
///                    the language options.
/// \param builder     Currently unused; kept so existing callers still build.
/// \param elementType Element type of the complex operands.
/// \returns the wider element type, or a null `mlir::Type` when no wider type
///          satisfies the exponent requirement.
static mlir::Type higherPrecisionElementTypeForComplexArithmetic(
    mlir::MLIRContext &context, clang::ASTContext &cc,
    CIRBaseBuilderTy &builder, mlir::Type elementType) {

  // Maps a floating-point type to the next wider one. Types with no wider
  // counterpart listed here are returned unchanged, which makes the exponent
  // check below fail and the function return a null type.
  auto getHigherPrecisionFPType = [&context](mlir::Type type) -> mlir::Type {
    if (mlir::isa<cir::FP16Type>(type))
      return cir::SingleType::get(&context);

    if (mlir::isa<cir::SingleType>(type) || mlir::isa<cir::BF16Type>(type))
      return cir::DoubleType::get(&context);

    if (mlir::isa<cir::DoubleType>(type))
      return cir::LongDoubleType::get(&context, type);

    return type;
  };

  // Looks up the target-specific semantics of a CIR floating-point type.
  auto getFloatTypeSemantics =
      [&cc](mlir::Type type) -> const llvm::fltSemantics & {
    const clang::TargetInfo &info = cc.getTargetInfo();
    if (mlir::isa<cir::FP16Type>(type))
      return info.getHalfFormat();

    if (mlir::isa<cir::BF16Type>(type))
      return info.getBFloat16Format();

    if (mlir::isa<cir::SingleType>(type))
      return info.getFloatFormat();

    if (mlir::isa<cir::DoubleType>(type))
      return info.getDoubleFormat();

    if (mlir::isa<cir::LongDoubleType>(type)) {
      if (cc.getLangOpts().OpenMP && cc.getLangOpts().OpenMPIsTargetDevice)
        llvm_unreachable("NYI Float type semantics with OpenMP");
      return info.getLongDoubleFormat();
    }

    if (mlir::isa<cir::FP128Type>(type)) {
      if (cc.getLangOpts().OpenMP && cc.getLangOpts().OpenMPIsTargetDevice)
        llvm_unreachable("NYI Float type semantics with OpenMP");
      return info.getFloat128Format();
    }

    // Use llvm_unreachable rather than assert(false): with assertions
    // compiled out, an assert would let control fall off the end of this
    // reference-returning lambda, which is undefined behaviour.
    llvm_unreachable("Unsupported float type semantics");
  };

  const mlir::Type higherElementType = getHigherPrecisionFPType(elementType);
  const llvm::fltSemantics &elementTypeSemantics =
      getFloatTypeSemantics(elementType);
  const llvm::fltSemantics &higherElementTypeSemantics =
      getFloatTypeSemantics(higherElementType);

  // Check that the promoted type can handle the intermediate values without
  // overflowing. This can be interpreted as:
  // (SmallerType.LargestFiniteVal * SmallerType.LargestFiniteVal) * 2 <=
  //      LargerType.LargestFiniteVal.
  // In terms of exponent it gives this formula:
  // (SmallerType.LargestFiniteVal * SmallerType.LargestFiniteVal
  // doubles the exponent of SmallerType.LargestFiniteVal)
  if (llvm::APFloat::semanticsMaxExponent(elementTypeSemantics) * 2 + 1 <=
      llvm::APFloat::semanticsMaxExponent(higherElementTypeSemantics)) {
    return higherElementType;
  }

  // The intermediate values can't be represented in the promoted type
  // without overflowing.
  return {};
}
385+
386+
static mlir::Value
387+
lowerComplexDiv(LoweringPreparePass &pass, CIRBaseBuilderTy &builder,
388+
mlir::Location loc, cir::ComplexDivOp op, mlir::Value lhsReal,
389+
mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag,
390+
mlir::MLIRContext &mlirCx, clang::ASTContext &cc) {
320391
cir::ComplexType complexTy = op.getType();
321392
if (mlir::isa<cir::FPTypeInterface>(complexTy.getElementType())) {
322393
cir::ComplexRangeKind range = op.getRange();
323-
if (range == cir::ComplexRangeKind::Improved ||
324-
(range == cir::ComplexRangeKind::Promoted && !op.getPromoted()))
394+
if (range == cir::ComplexRangeKind::Improved)
325395
return buildRangeReductionComplexDiv(builder, loc, lhsReal, lhsImag,
326396
rhsReal, rhsImag);
397+
327398
if (range == cir::ComplexRangeKind::Full)
328399
return buildComplexBinOpLibCall(pass, builder, &getComplexDivLibCallName,
329400
loc, complexTy, lhsReal, lhsImag, rhsReal,
330401
rhsImag);
402+
403+
if (range == cir::ComplexRangeKind::Promoted) {
404+
mlir::Type originalElementType = complexTy.getElementType();
405+
mlir::Type higherPrecisionElementType =
406+
higherPrecisionElementTypeForComplexArithmetic(mlirCx, cc, builder,
407+
originalElementType);
408+
409+
if (!higherPrecisionElementType)
410+
return buildRangeReductionComplexDiv(builder, loc, lhsReal, lhsImag,
411+
rhsReal, rhsImag);
412+
413+
cir::CastKind floatingCastKind = cir::CastKind::floating;
414+
lhsReal = builder.createCast(floatingCastKind, lhsReal,
415+
higherPrecisionElementType);
416+
lhsImag = builder.createCast(floatingCastKind, lhsImag,
417+
higherPrecisionElementType);
418+
rhsReal = builder.createCast(floatingCastKind, rhsReal,
419+
higherPrecisionElementType);
420+
rhsImag = builder.createCast(floatingCastKind, rhsImag,
421+
higherPrecisionElementType);
422+
423+
mlir::Value algebraicResult = buildAlgebraicComplexDiv(
424+
builder, loc, lhsReal, lhsImag, rhsReal, rhsImag);
425+
426+
mlir::Value resultReal = builder.createComplexReal(loc, algebraicResult);
427+
mlir::Value resultImag = builder.createComplexImag(loc, algebraicResult);
428+
429+
mlir::Value finalReal =
430+
builder.createCast(floatingCastKind, resultReal, originalElementType);
431+
mlir::Value finalImag =
432+
builder.createCast(floatingCastKind, resultImag, originalElementType);
433+
return builder.createComplexCreate(loc, finalReal, finalImag);
434+
}
331435
}
332436

333437
return buildAlgebraicComplexDiv(builder, loc, lhsReal, lhsImag, rhsReal,
@@ -345,8 +449,9 @@ void LoweringPreparePass::lowerComplexDivOp(cir::ComplexDivOp op) {
345449
mlir::Value rhsReal = builder.createComplexReal(loc, rhs);
346450
mlir::Value rhsImag = builder.createComplexImag(loc, rhs);
347451

348-
mlir::Value loweredResult = lowerComplexDiv(*this, builder, loc, op, lhsReal,
349-
lhsImag, rhsReal, rhsImag);
452+
mlir::Value loweredResult =
453+
lowerComplexDiv(*this, builder, loc, op, lhsReal, lhsImag, rhsReal,
454+
rhsImag, getContext(), *astCtx);
350455
op.replaceAllUsesWith(loweredResult);
351456
op.erase();
352457
}

clang/test/CIR/CodeGen/complex-mul-div.cpp

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -476,25 +476,21 @@ void foo3() {
476476
// OGCG-IMPROVED: store float %[[RESULT_REAL]], ptr %[[C_REAL_PTR]], align 4
477477
// OGCG-IMPROVED: store float %[[RESULT_IMAG]], ptr %[[C_IMAG_PTR]], align 4
478478

479-
// CIR-BEFORE-PROMOTED: %{{.*}} = cir.complex.div {{.*}}, {{.*}} range(promoted) : !cir.complex<!cir.double>
479+
// CIR-BEFORE-PROMOTED: %{{.*}} = cir.complex.div {{.*}}, {{.*}} range(promoted) : !cir.complex<!cir.float>
480480

481481
// LLVM-PROMOTED: %[[A_ADDR:.*]] = alloca { float, float }, i64 1, align 4
482482
// LLVM-PROMOTED: %[[B_ADDR:.*]] = alloca { float, float }, i64 1, align 4
483483
// LLVM-PROMOTED: %[[C_ADDR:.*]] = alloca { float, float }, i64 1, align 4
484484
// LLVM-PROMOTED: %[[TMP_A:.*]] = load { float, float }, ptr %[[A_ADDR]], align 4
485+
// LLVM-PROMOTED: %[[TMP_B:.*]] = load { float, float }, ptr %[[B_ADDR]], align 4
485486
// LLVM-PROMOTED: %[[A_REAL:.*]] = extractvalue { float, float } %[[TMP_A]], 0
486487
// LLVM-PROMOTED: %[[A_IMAG:.*]] = extractvalue { float, float } %[[TMP_A]], 1
487-
// LLVM-PROMOTED: %[[A_REAL_F64:.*]] = fpext float %[[A_REAL]] to double
488-
// LLVM-PROMOTED: %[[A_IMAG_F64:.*]] = fpext float %[[A_IMAG]] to double
489-
// LLVM-PROMOTED: %[[TMP_A_CF64:.*]] = insertvalue { double, double } {{.*}}, double %[[A_REAL_F64]], 0
490-
// LLVM-PROMOTED: %[[A_CF64:.*]] = insertvalue { double, double } %[[TMP_A_CF64]], double %[[A_IMAG_F64]], 1
491-
// LLVM-PROMOTED: %[[TMP_B:.*]] = load { float, float }, ptr %[[B_ADDR]], align 4
492488
// LLVM-PROMOTED: %[[B_REAL:.*]] = extractvalue { float, float } %[[TMP_B]], 0
493489
// LLVM-PROMOTED: %[[B_IMAG:.*]] = extractvalue { float, float } %[[TMP_B]], 1
490+
// LLVM-PROMOTED: %[[A_REAL_F64:.*]] = fpext float %[[A_REAL]] to double
491+
// LLVM-PROMOTED: %[[A_IMAG_F64:.*]] = fpext float %[[A_IMAG]] to double
494492
// LLVM-PROMOTED: %[[B_REAL_F64:.*]] = fpext float %[[B_REAL]] to double
495493
// LLVM-PROMOTED: %[[B_IMAG_F64:.*]] = fpext float %[[B_IMAG]] to double
496-
// LLVM-PROMOTED: %[[TMP_B_CF64:.*]] = insertvalue { double, double } {{.*}}, double %[[B_REAL_F64]], 0
497-
// LLVM-PROMOTED: %[[B_CF64:.*]] = insertvalue { double, double } %[[TMP_B_CF64]], double %[[B_IMAG_F64]], 1
498494
// LLVM-PROMOTED: %[[MUL_AR_BR:.*]] = fmul double %[[A_REAL_F64]], %[[B_REAL_F64]]
499495
// LLVM-PROMOTED: %[[MUL_AI_BI:.*]] = fmul double %[[A_IMAG_F64]], %[[B_IMAG_F64]]
500496
// LLVM-PROMOTED: %[[MUL_BR_BR:.*]] = fmul double %[[B_REAL_F64]], %[[B_REAL_F64]]
@@ -503,16 +499,16 @@ void foo3() {
503499
// LLVM-PROMOTED: %[[ADD_BRBR_BIBI:.*]] = fadd double %[[MUL_BR_BR]], %[[MUL_BI_BI]]
504500
// LLVM-PROMOTED: %[[RESULT_REAL:.*]] = fdiv double %[[ADD_ARBR_AIBI]], %[[ADD_BRBR_BIBI]]
505501
// LLVM-PROMOTED: %[[MUL_AI_BR:.*]] = fmul double %[[A_IMAG_F64]], %[[B_REAL_F64]]
506-
// LLVM-PROMOTED: %[[MUL_AR_BI:.*]] = fmul double %[[A_REAL_F64]], %[[B_IMAG_F64]]
507-
// LLVM-PROMOTED: %[[SUB_AIBR_ARBI:.*]] = fsub double %[[MUL_AI_BR]], %[[MUL_AR_BI]]
508-
// LLVM-PROMOTED: %[[RESULT_IMAG:.*]] = fdiv double %[[SUB_AIBR_ARBI]], %23
509-
// LLVM-PROMOTED: %[[TMP_RESULT_CF64:.*]] = insertvalue { double, double } {{.*}}, double %[[RESULT_REAL]], 0
510-
// LLVM-PROMOTED: %[[RESULT_CF64:.*]] = insertvalue { double, double } %[[TMP_RESULT_CF64]], double %[[RESULT_IMAG]], 1
502+
// LLVM-PROMOTED: %[[MUL_AR_BR:.*]] = fmul double %[[A_REAL_F64]], %[[B_IMAG_F64]]
503+
// LLVM-PROMOTED: %[[SUB_AIBR_ARBI:.*]] = fsub double %[[MUL_AI_BR]], %[[MUL_AR_BR]]
504+
// LLVM-PROMOTED: %[[RESULT_IMAG:.*]] = fdiv double %[[SUB_AIBR_ARBI]], %[[ADD_BRBR_BIBI]]
505+
// LLVM-PROMOTED: %[[TMP_RESULT_F64:.*]] = insertvalue { double, double } {{.*}}, double %[[RESULT_REAL]], 0
506+
// LLVM-PROMOTED: %[[RESULT_F64:.*]] = insertvalue { double, double } %[[TMP_RESULT_F64]], double %[[RESULT_IMAG]], 1
511507
// LLVM-PROMOTED: %[[RESULT_REAL_F32:.*]] = fptrunc double %[[RESULT_REAL]] to float
512508
// LLVM-PROMOTED: %[[RESULT_IMAG_F32:.*]] = fptrunc double %[[RESULT_IMAG]] to float
513-
// LLVM-PROMOTED: %[[TMP_RESULT_CF32:.*]] = insertvalue { float, float } {{.*}}, float %[[RESULT_REAL_F32]], 0
514-
// LLVM-PROMOTED: %[[RESULT_CF32:.*]] = insertvalue { float, float } %[[TMP_RESULT_CF32]], float %[[RESULT_IMAG_F32]], 1
515-
// LLVM-PROMOTED: store { float, float } %[[RESULT_CF32]], ptr %[[C_ADDR]], align 4
509+
// LLVM-PROMOTED: %[[TMP_RESULT_F32:.*]] = insertvalue { float, float } {{.*}}, float %[[RESULT_REAL_F32]], 0
510+
// LLVM-PROMOTED: %[[RESULT_F32:.*]] = insertvalue { float, float } %[[TMP_RESULT_F32]], float %[[RESULT_IMAG_F32]], 1
511+
// LLVM-PROMOTED: store { float, float } %[[RESULT_F32]], ptr %[[C_ADDR]], align 4
516512

517513
// OGCG-PROMOTED: %[[A_ADDR:.*]] = alloca { float, float }, align 4
518514
// OGCG-PROMOTED: %[[B_ADDR:.*]] = alloca { float, float }, align 4
@@ -537,9 +533,9 @@ void foo3() {
537533
// OGCG-PROMOTED: %[[ADD_BRBR_BIBI:.*]] = fadd double %[[MUL_BR_BR]], %[[MUL_BI_BI]]
538534
// OGCG-PROMOTED: %[[MUL_AI_BR:.*]] = fmul double %[[A_IMAG_F64]], %[[B_REAL_F64]]
539535
// OGCG-PROMOTED: %[[MUL_AR_BI:.*]] = fmul double %[[A_REAL_F64]], %[[B_IMAG_F64]]
540-
// OGCG-PROMOTED: %[[SUB_AIBR_BRBI:.*]] = fsub double %[[MUL_AI_BR]], %[[MUL_AR_BI]]
536+
// OGCG-PROMOTED: %[[SUB_AIBR_ARBI:.*]] = fsub double %[[MUL_AI_BR]], %[[MUL_AR_BI]]
541537
// OGCG-PROMOTED: %[[RESULT_REAL:.*]] = fdiv double %[[ADD_ARBR_AIBI]], %[[ADD_BRBR_BIBI]]
542-
// OGCG-PROMOTED: %[[RESULT_IMAG:.*]] = fdiv double %[[SUB_AIBR_BRBI]], %[[ADD_BRBR_BIBI]]
538+
// OGCG-PROMOTED: %[[RESULT_IMAG:.*]] = fdiv double %[[SUB_AIBR_ARBI]], %[[ADD_BRBR_BIBI]]
543539
// OGCG-PROMOTED: %[[UNPROMOTION_RESULT_REAL:.*]] = fptrunc double %[[RESULT_REAL]] to float
544540
// OGCG-PROMOTED: %[[UNPROMOTION_RESULT_IMAG:.*]] = fptrunc double %[[RESULT_IMAG]] to float
545541
// OGCG-PROMOTED: %[[C_REAL_PTR:.*]] = getelementptr inbounds nuw { float, float }, ptr %[[C_ADDR]], i32 0, i32 0

0 commit comments

Comments
 (0)