Skip to content

Commit 6613d4c

Browse files
committed
[CIR] Upstream DivOp for ComplexType
1 parent 91418ec commit 6613d4c

File tree

5 files changed

+602
-12
lines changed

5 files changed

+602
-12
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2966,7 +2966,7 @@ def CIR_ComplexSubOp : CIR_Op<"complex.sub", [
29662966
}
29672967

29682968
//===----------------------------------------------------------------------===//
2969-
// ComplexMulOp
2969+
// ComplexMulOp & ComplexDivOp
29702970
//===----------------------------------------------------------------------===//
29712971

29722972
def CIR_ComplexRangeKind : CIR_I32EnumAttr<
@@ -3013,6 +3013,44 @@ def CIR_ComplexMulOp : CIR_Op<"complex.mul", [
30133013
}];
30143014
}
30153015

3016+
def CIR_ComplexDivOp : CIR_Op<"complex.div", [
  Pure, SameOperandsAndResultType
]> {
  let summary = "Complex division";
  let description = [{
    The `cir.complex.div` operation takes two complex numbers and returns
    their quotient.

    `range` selects the implementation used when the operation is lowered.
    It only affects complex values with a floating-point element type:

    - 'improved' produces Smith's algorithm for complex division.
    - 'promoted' expects the operands to have been promoted to a higher
      precision type, which is recorded by the `promoted` unit attribute; the
      algebraic formula is then used. If the operands could not be promoted
      (`promoted` is absent), Smith's algorithm is used instead.
    - 'full' calls a runtime-library function (for example `__divsc3`).
    - 'basic' uses the algebraic formula with no special handling for NaN
      values.

    For element types that are not floating-point, `range` is ignored and the
    algebraic formula is always used.

    Example:

    ```mlir
    %2 = cir.complex.div %0, %1 range(basic) : !cir.complex<!cir.float>
    %2 = cir.complex.div %0, %1 range(full) : !cir.complex<!cir.float>
    ```
  }];

  let arguments = (ins
    CIR_ComplexType:$lhs,
    CIR_ComplexType:$rhs,
    CIR_ComplexRangeKind:$range,
    UnitAttr:$promoted
  );

  let results = (outs CIR_ComplexType:$result);

  let assemblyFormat = [{
    $lhs `,` $rhs `range` `(` $range `)` `:` qualified(type($result)) attr-dict
  }];
}
3053+
30163054
//===----------------------------------------------------------------------===//
30173055
// Bit Manipulation Operations
30183056
//===----------------------------------------------------------------------===//

clang/lib/CIR/CodeGen/CIRGenExprComplex.cpp

Lines changed: 57 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ namespace {
1010
class ComplexExprEmitter : public StmtVisitor<ComplexExprEmitter, mlir::Value> {
1111
CIRGenFunction &cgf;
1212
CIRGenBuilderTy &builder;
13+
bool fpHasBeenPromoted = false;
1314

1415
public:
1516
explicit ComplexExprEmitter(CIRGenFunction &cgf)
@@ -128,15 +129,43 @@ class ComplexExprEmitter : public StmtVisitor<ComplexExprEmitter, mlir::Value> {
128129
mlir::Value emitBinAdd(const BinOpInfo &op);
129130
mlir::Value emitBinSub(const BinOpInfo &op);
130131
mlir::Value emitBinMul(const BinOpInfo &op);
132+
mlir::Value emitBinDiv(const BinOpInfo &op);
133+
134+
/// Returns the complex type built on the next higher precision
/// floating-point type above \p elementType, or a null QualType when that
/// wider type cannot hold the intermediate values of the computation.
///
/// On success this also sets `fpHasBeenPromoted`. \p isDivOpCode is accepted
/// but not consulted by this function.
QualType higherPrecisionTypeForComplexArithmetic(QualType elementType,
                                                 bool isDivOpCode) {
  ASTContext &ctx = cgf.getContext();
  const QualType widerElementType = ctx.GetHigherPrecisionFPType(elementType);

  const llvm::fltSemantics &narrowSemantics =
      ctx.getFloatTypeSemantics(elementType);
  const llvm::fltSemantics &wideSemantics =
      ctx.getFloatTypeSemantics(widerElementType);
  const auto narrowMaxExp = llvm::APFloat::semanticsMaxExponent(narrowSemantics);
  const auto wideMaxExp = llvm::APFloat::semanticsMaxExponent(wideSemantics);

  // The wider type must be able to represent
  //   (Narrow.LargestFiniteVal * Narrow.LargestFiniteVal) * 2
  // without overflowing. Squaring the largest finite value doubles its
  // exponent and the final multiplication by two adds one more, hence the
  // requirement  2 * narrowMaxExp + 1 <= wideMaxExp.
  if (narrowMaxExp * 2 + 1 > wideMaxExp) {
    // The intermediate values would overflow even in the wider type.
    return QualType();
  }

  fpHasBeenPromoted = true;
  return ctx.getComplexType(widerElementType);
}
131161

132162
QualType getPromotionType(QualType ty, bool isDivOpCode = false) {
133163
if (auto *complexTy = ty->getAs<ComplexType>()) {
134164
QualType elementTy = complexTy->getElementType();
135165
if (isDivOpCode && elementTy->isFloatingType() &&
136166
cgf.getLangOpts().getComplexRange() ==
137167
LangOptions::ComplexRangeKind::CX_Promoted) {
138-
cgf.cgm.errorNYI("HigherPrecisionTypeForComplexArithmetic");
139-
return QualType();
168+
return higherPrecisionTypeForComplexArithmetic(elementTy, isDivOpCode);
140169
}
141170

142171
if (elementTy.UseExcessPrecision(cgf.getContext()))
@@ -154,13 +183,14 @@ class ComplexExprEmitter : public StmtVisitor<ComplexExprEmitter, mlir::Value> {
154183
e->getType(), e->getOpcode() == BinaryOperatorKind::BO_Div); \
155184
mlir::Value result = emitBin##OP(emitBinOps(e, promotionTy)); \
156185
if (!promotionTy.isNull()) \
157-
cgf.cgm.errorNYI("Binop emitUnPromotedValue"); \
186+
result = cgf.emitUnPromotedValue(result, e->getType()); \
158187
return result; \
159188
}
160189

161190
HANDLEBINOP(Add)
162191
HANDLEBINOP(Sub)
163192
HANDLEBINOP(Mul)
193+
HANDLEBINOP(Div)
164194
#undef HANDLEBINOP
165195

166196
// Compound assignments.
@@ -858,6 +888,22 @@ mlir::Value ComplexExprEmitter::emitBinMul(const BinOpInfo &op) {
858888
return builder.createComplexCreate(op.loc, newReal, newImag);
859889
}
860890

891+
/// Emits the division described by \p op.
///
/// When both operands are complex values a `cir::ComplexDivOp` is created,
/// carrying the complex range taken from the operation's FP features and a
/// flag telling whether the operands were promoted to a higher precision
/// type. Any other operand combination is not implemented yet: a NYI
/// diagnostic is reported and a null value is returned.
mlir::Value ComplexExprEmitter::emitBinDiv(const BinOpInfo &op) {
  assert(!cir::MissingFeatures::fastMathFlags());
  assert(!cir::MissingFeatures::cgFPOptionsRAII());

  if (mlir::isa<cir::ComplexType>(op.lhs.getType()) &&
      mlir::isa<cir::ComplexType>(op.rhs.getType())) {
    cir::ComplexRangeKind rangeKind =
        getComplexRangeAttr(op.fpFeatures.getComplexRange());
    return builder.create<cir::ComplexDivOp>(op.loc, op.lhs, op.rhs, rangeKind,
                                             fpHasBeenPromoted);
  }

  // Name this function in the diagnostic; it previously reported the
  // multiplication emitter's name.
  cgf.cgm.errorNYI("ComplexExprEmitter::emitBinDiv between Complex & Scalar");
  return {};
}
906+
861907
LValue CIRGenFunction::emitComplexAssignmentLValue(const BinaryOperator *e) {
862908
assert(e->getOpcode() == BO_Assign && "Expected assign op");
863909

@@ -954,6 +1000,14 @@ mlir::Value CIRGenFunction::emitPromotedValue(mlir::Value result,
9541000
convertType(promotionType));
9551001
}
9561002

1003+
/// Casts the complex value \p result to the CIR type that corresponds to
/// \p unPromotionType using a `float_complex` cast. \p result must not be an
/// integer complex value.
mlir::Value CIRGenFunction::emitUnPromotedValue(mlir::Value result,
                                                QualType unPromotionType) {
  assert(!mlir::cast<cir::ComplexType>(result.getType()).isIntegerComplex() &&
         "integral complex will never be promoted");

  mlir::Type targetType = convertType(unPromotionType);
  return builder.createCast(cir::CastKind::float_complex, result, targetType);
}
1010+
9571011
LValue CIRGenFunction::emitScalarCompoundAssignWithComplex(
9581012
const CompoundAssignOperator *e, mlir::Value &result) {
9591013
CompoundFunc op = getComplexOp(e->getOpcode());

clang/lib/CIR/CodeGen/CIRGenFunction.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1302,6 +1302,8 @@ class CIRGenFunction : public CIRGenTypeCache {
13021302

13031303
LValue emitUnaryOpLValue(const clang::UnaryOperator *e);
13041304

1305+
/// Casts the complex value \p result to the CIR type converted from
/// \p unPromotionType with a `float_complex` cast. \p result must not be an
/// integer complex value.
mlir::Value emitUnPromotedValue(mlir::Value result, QualType unPromotionType);
1306+
13051307
/// Emit a reached-unreachable diagnostic if \p loc is valid and runtime
13061308
/// checking is enabled. Otherwise, just emit an unreachable instruction.
13071309
/// \p createNewBlock indicates whether to create a new block for the IR

clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp

Lines changed: 174 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
#include "PassDetail.h"
1010
#include "clang/AST/ASTContext.h"
11-
#include "clang/AST/CharUnits.h"
1211
#include "clang/CIR/Dialect/Builder/CIRBaseBuilder.h"
1312
#include "clang/CIR/Dialect/IR/CIRDialect.h"
1413
#include "clang/CIR/Dialect/IR/CIROpsEnums.h"
@@ -27,6 +26,7 @@ struct LoweringPreparePass : public LoweringPrepareBase<LoweringPreparePass> {
2726

2827
void runOnOp(mlir::Operation *op);
2928
void lowerCastOp(cir::CastOp op);
29+
void lowerComplexDivOp(cir::ComplexDivOp op);
3030
void lowerComplexMulOp(cir::ComplexMulOp op);
3131
void lowerUnaryOp(cir::UnaryOp op);
3232
void lowerArrayDtor(cir::ArrayDtor op);
@@ -181,6 +181,176 @@ static mlir::Value buildComplexBinOpLibCall(
181181
return call.getResult();
182182
}
183183

184+
/// Maps floating-point semantics to the name of the runtime-library routine
/// that performs complex division for that format. Reaching an unlisted
/// format is treated as unreachable.
static llvm::StringRef
getComplexDivLibCallName(llvm::APFloat::Semantics semantics) {
  switch (semantics) {
  case llvm::APFloat::S_IEEEhalf:
    return "__divhc3";
  case llvm::APFloat::S_IEEEsingle:
    return "__divsc3";
  case llvm::APFloat::S_IEEEdouble:
    return "__divdc3";
  case llvm::APFloat::S_x87DoubleExtended:
    return "__divxc3";
  // Both of these formats share the same routine name.
  case llvm::APFloat::S_PPCDoubleDouble:
  case llvm::APFloat::S_IEEEquad:
    return "__divtc3";
  default:
    break;
  }
  llvm_unreachable("unsupported floating point type");
}
203+
204+
/// Emits complex division using the plain algebraic formula
///   (a+bi) / (c+di) = ((ac+bd)/(cc+dd)) + ((bc-ad)/(cc+dd))i
/// where a/b are the real/imaginary parts of the dividend and c/d those of
/// the divisor. Returns the complex value created from the two quotients.
static mlir::Value
buildAlgebraicComplexDiv(CIRBaseBuilderTy &builder, mlir::Location loc,
                         mlir::Value lhsReal, mlir::Value lhsImag,
                         mlir::Value rhsReal, mlir::Value rhsImag) {
  using Kind = cir::BinOpKind;
  auto emit = [&](mlir::Value x, Kind kind, mlir::Value y) -> mlir::Value {
    return builder.createBinop(loc, x, kind, y);
  };

  // Real part: (ac + bd) / (cc + dd).
  mlir::Value realTimesReal = emit(lhsReal, Kind::Mul, rhsReal);
  mlir::Value imagTimesImag = emit(lhsImag, Kind::Mul, rhsImag);
  mlir::Value rhsRealSquared = emit(rhsReal, Kind::Mul, rhsReal);
  mlir::Value rhsImagSquared = emit(rhsImag, Kind::Mul, rhsImag);
  mlir::Value realNumerator = emit(realTimesReal, Kind::Add, imagTimesImag);
  mlir::Value denominator = emit(rhsRealSquared, Kind::Add, rhsImagSquared);
  mlir::Value quotientReal = emit(realNumerator, Kind::Div, denominator);

  // Imaginary part: (bc - ad) / (cc + dd), reusing the denominator.
  mlir::Value imagTimesReal = emit(lhsImag, Kind::Mul, rhsReal);
  mlir::Value realTimesImag = emit(lhsReal, Kind::Mul, rhsImag);
  mlir::Value imagNumerator = emit(imagTimesReal, Kind::Sub, realTimesImag);
  mlir::Value quotientImag = emit(imagNumerator, Kind::Div, denominator);

  return builder.createComplexCreate(loc, quotientReal, quotientImag);
}
233+
234+
/// Emits complex division using Smith's algorithm.
/// SMITH, R. L. Algorithm 116: Complex division. Commun. ACM 5, 8 (1962).
///
/// With lhs := a+bi, rhs := c+di and result := e+fi, the emitted code
/// corresponds to:
///
///   if fabs(c) >= fabs(d):
///     r   := d / c
///     tmp := c + r*d
///     e    = (a + b*r) / tmp
///     f    = (b - a*r) / tmp
///   else:
///     r   := c / d
///     tmp := d + r*c
///     e    = (a*r + b) / tmp
///     f    = (b*r - a) / tmp
///
/// The selection is expressed as a `cir::TernaryOp` whose result is returned.
static mlir::Value
buildRangeReductionComplexDiv(CIRBaseBuilderTy &builder, mlir::Location loc,
                              mlir::Value lhsReal, mlir::Value lhsImag,
                              mlir::Value rhsReal, mlir::Value rhsImag) {
  using Kind = cir::BinOpKind;
  auto emit = [&](mlir::Value x, Kind kind, mlir::Value y) -> mlir::Value {
    return builder.createBinop(loc, x, kind, y);
  };

  // Condition: fabs(c) >= fabs(d).
  auto absRhsReal = builder.create<cir::FAbsOp>(loc, rhsReal);
  auto absRhsImag = builder.create<cir::FAbsOp>(loc, rhsImag);
  cir::CmpOp realDominates =
      builder.createCompare(loc, cir::CmpOpKind::ge, absRhsReal, absRhsImag);

  // Both region callbacks deliberately ignore their arguments and emit
  // through the enclosing `builder` and `loc`.
  auto whenRealDominates = [&](mlir::OpBuilder &, mlir::Location) {
    mlir::Value ratio = emit(rhsImag, Kind::Div, rhsReal);       // r := d / c
    mlir::Value ratioTimesImag = emit(ratio, Kind::Mul, rhsImag); // r*d
    mlir::Value scale = emit(rhsReal, Kind::Add, ratioTimesImag); // c + r*d

    mlir::Value imagTimesRatio = emit(lhsImag, Kind::Mul, ratio);  // b*r
    mlir::Value realNumer = emit(lhsReal, Kind::Add, imagTimesRatio); // a + b*r
    mlir::Value quotReal = emit(realNumer, Kind::Div, scale);

    mlir::Value realTimesRatio = emit(lhsReal, Kind::Mul, ratio);  // a*r
    mlir::Value imagNumer = emit(lhsImag, Kind::Sub, realTimesRatio); // b - a*r
    mlir::Value quotImag = emit(imagNumer, Kind::Div, scale);

    mlir::Value quotient = builder.createComplexCreate(loc, quotReal, quotImag);
    builder.createYield(loc, quotient);
  };

  auto whenImagDominates = [&](mlir::OpBuilder &, mlir::Location) {
    mlir::Value ratio = emit(rhsReal, Kind::Div, rhsImag);       // r := c / d
    mlir::Value ratioTimesReal = emit(ratio, Kind::Mul, rhsReal); // r*c
    mlir::Value scale = emit(rhsImag, Kind::Add, ratioTimesReal); // d + r*c

    mlir::Value realTimesRatio = emit(lhsReal, Kind::Mul, ratio);  // a*r
    mlir::Value realNumer = emit(realTimesRatio, Kind::Add, lhsImag); // a*r + b
    mlir::Value quotReal = emit(realNumer, Kind::Div, scale);

    mlir::Value imagTimesRatio = emit(lhsImag, Kind::Mul, ratio);  // b*r
    mlir::Value imagNumer = emit(imagTimesRatio, Kind::Sub, lhsReal); // b*r - a
    mlir::Value quotImag = emit(imagNumer, Kind::Div, scale);

    mlir::Value quotient = builder.createComplexCreate(loc, quotReal, quotImag);
    builder.createYield(loc, quotient);
  };

  auto select = builder.create<cir::TernaryOp>(
      loc, realDominates, whenRealDominates, whenImagDominates);
  return select.getResult();
}
314+
315+
/// Picks the lowering strategy for \p op and emits it from the already
/// extracted real/imaginary parts of both operands.
///
/// For floating-point element types:
///  - `Full` emits a runtime-library call,
///  - `Improved`, or `Promoted` without the op's `promoted` flag, emits the
///    range-reduction (Smith) sequence.
/// Everything else, including non-floating-point element types, emits the
/// algebraic formula.
static mlir::Value lowerComplexDiv(LoweringPreparePass &pass,
                                   CIRBaseBuilderTy &builder,
                                   mlir::Location loc, cir::ComplexDivOp op,
                                   mlir::Value lhsReal, mlir::Value lhsImag,
                                   mlir::Value rhsReal, mlir::Value rhsImag) {
  cir::ComplexType resultTy = op.getType();
  const bool hasFPElements =
      mlir::isa<cir::FPTypeInterface>(resultTy.getElementType());

  if (hasFPElements) {
    const cir::ComplexRangeKind kind = op.getRange();

    if (kind == cir::ComplexRangeKind::Full)
      return buildComplexBinOpLibCall(pass, builder, &getComplexDivLibCallName,
                                      loc, resultTy, lhsReal, lhsImag, rhsReal,
                                      rhsImag);

    const bool promotionRequestedButMissing =
        kind == cir::ComplexRangeKind::Promoted && !op.getPromoted();
    if (kind == cir::ComplexRangeKind::Improved || promotionRequestedButMissing)
      return buildRangeReductionComplexDiv(builder, loc, lhsReal, lhsImag,
                                           rhsReal, rhsImag);
  }

  return buildAlgebraicComplexDiv(builder, loc, lhsReal, lhsImag, rhsReal,
                                  rhsImag);
}
336+
337+
/// Replaces \p op with its expanded form: the operands are split into real
/// and imaginary parts right after the op, the division is emitted by
/// lowerComplexDiv, all uses are redirected to that result and the op is
/// erased.
void LoweringPreparePass::lowerComplexDivOp(cir::ComplexDivOp op) {
  cir::CIRBaseBuilderTy builder(getContext());
  builder.setInsertionPointAfter(op);
  mlir::Location loc = op.getLoc();

  mlir::TypedValue<cir::ComplexType> dividend = op.getLhs();
  mlir::TypedValue<cir::ComplexType> divisor = op.getRhs();

  mlir::Value dividendReal = builder.createComplexReal(loc, dividend);
  mlir::Value dividendImag = builder.createComplexImag(loc, dividend);
  mlir::Value divisorReal = builder.createComplexReal(loc, divisor);
  mlir::Value divisorImag = builder.createComplexImag(loc, divisor);

  mlir::Value quotient =
      lowerComplexDiv(*this, builder, loc, op, dividendReal, dividendImag,
                      divisorReal, divisorImag);

  op.replaceAllUsesWith(quotient);
  op.erase();
}
353+
184354
static llvm::StringRef
185355
getComplexMulLibCallName(llvm::APFloat::Semantics semantics) {
186356
switch (semantics) {
@@ -412,6 +582,8 @@ void LoweringPreparePass::runOnOp(mlir::Operation *op) {
412582
lowerArrayDtor(arrayDtor);
413583
else if (auto cast = mlir::dyn_cast<cir::CastOp>(op))
414584
lowerCastOp(cast);
585+
else if (auto complexDiv = mlir::dyn_cast<cir::ComplexDivOp>(op))
586+
lowerComplexDivOp(complexDiv);
415587
else if (auto complexMul = mlir::dyn_cast<cir::ComplexMulOp>(op))
416588
lowerComplexMulOp(complexMul);
417589
else if (auto unary = mlir::dyn_cast<cir::UnaryOp>(op))
@@ -427,7 +599,7 @@ void LoweringPreparePass::runOnOperation() {
427599

428600
op->walk([&](mlir::Operation *op) {
429601
if (mlir::isa<cir::ArrayCtor, cir::ArrayDtor, cir::CastOp,
430-
cir::ComplexMulOp, cir::UnaryOp>(op))
602+
cir::ComplexMulOp, cir::ComplexDivOp, cir::UnaryOp>(op))
431603
opsToTransform.push_back(op);
432604
});
433605

0 commit comments

Comments
 (0)