15
15
#include " clang/CIR/Dialect/Passes.h"
16
16
#include " clang/CIR/MissingFeatures.h"
17
17
18
- #include < iostream>
19
18
#include < memory>
20
19
21
20
using namespace mlir ;
@@ -28,21 +27,47 @@ struct LoweringPreparePass : public LoweringPrepareBase<LoweringPreparePass> {
28
27
29
28
void runOnOp (mlir::Operation *op);
30
29
void lowerCastOp (cir::CastOp op);
30
+ void lowerComplexMulOp (cir::ComplexMulOp op);
31
31
void lowerUnaryOp (cir::UnaryOp op);
32
32
void lowerArrayDtor (cir::ArrayDtor op);
33
33
void lowerArrayCtor (cir::ArrayCtor op);
34
34
35
+ cir::FuncOp buildRuntimeFunction (
36
+ mlir::OpBuilder &builder, llvm::StringRef name, mlir::Location loc,
37
+ cir::FuncType type,
38
+ cir::GlobalLinkageKind linkage = cir::GlobalLinkageKind::ExternalLinkage);
39
+
35
40
// /
36
41
// / AST related
37
42
// / -----------
38
43
39
44
clang::ASTContext *astCtx;
40
45
46
+ // / Tracks current module.
47
+ mlir::ModuleOp mlirModule;
48
+
41
49
void setASTContext (clang::ASTContext *c) { astCtx = c; }
42
50
};
43
51
44
52
} // namespace
45
53
54
+ cir::FuncOp LoweringPreparePass::buildRuntimeFunction (
55
+ mlir::OpBuilder &builder, llvm::StringRef name, mlir::Location loc,
56
+ cir::FuncType type, cir::GlobalLinkageKind linkage) {
57
+ cir::FuncOp f = dyn_cast_or_null<FuncOp>(SymbolTable::lookupNearestSymbolFrom (
58
+ mlirModule, StringAttr::get (mlirModule->getContext (), name)));
59
+ if (!f) {
60
+ f = builder.create <cir::FuncOp>(loc, name, type);
61
+ f.setLinkageAttr (
62
+ cir::GlobalLinkageKindAttr::get (builder.getContext (), linkage));
63
+ mlir::SymbolTable::setSymbolVisibility (
64
+ f, mlir::SymbolTable::Visibility::Private);
65
+
66
+ assert (!cir::MissingFeatures::opFuncExtraAttrs ());
67
+ }
68
+ return f;
69
+ }
70
+
46
71
static mlir::Value lowerScalarToComplexCast (mlir::MLIRContext &ctx,
47
72
cir::CastOp op) {
48
73
cir::CIRBaseBuilderTy builder (ctx);
@@ -128,6 +153,124 @@ void LoweringPreparePass::lowerCastOp(cir::CastOp op) {
128
153
}
129
154
}
130
155
156
+ static mlir::Value buildComplexBinOpLibCall (
157
+ LoweringPreparePass &pass, CIRBaseBuilderTy &builder,
158
+ llvm::StringRef (*libFuncNameGetter)(llvm::APFloat::Semantics),
159
+ mlir::Location loc, cir::ComplexType ty, mlir::Value lhsReal,
160
+ mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag) {
161
+ cir::FPTypeInterface elementTy =
162
+ mlir::cast<cir::FPTypeInterface>(ty.getElementType ());
163
+
164
+ llvm::StringRef libFuncName = libFuncNameGetter (
165
+ llvm::APFloat::SemanticsToEnum (elementTy.getFloatSemantics ()));
166
+ llvm::SmallVector<mlir::Type, 4 > libFuncInputTypes (4 , elementTy);
167
+
168
+ cir::FuncType libFuncTy = cir::FuncType::get (libFuncInputTypes, ty);
169
+
170
+ // Insert a declaration for the runtime function to be used in Complex
171
+ // multiplication and division when needed
172
+ cir::FuncOp libFunc;
173
+ {
174
+ mlir::OpBuilder::InsertionGuard ipGuard{builder};
175
+ builder.setInsertionPointToStart (pass.mlirModule .getBody ());
176
+ libFunc = pass.buildRuntimeFunction (builder, libFuncName, loc, libFuncTy);
177
+ }
178
+
179
+ cir::CallOp call =
180
+ builder.createCallOp (loc, libFunc, {lhsReal, lhsImag, rhsReal, rhsImag});
181
+ return call.getResult ();
182
+ }
183
+
184
+ static llvm::StringRef
185
+ getComplexMulLibCallName (llvm::APFloat::Semantics semantics) {
186
+ switch (semantics) {
187
+ case llvm::APFloat::S_IEEEhalf:
188
+ return " __mulhc3" ;
189
+ case llvm::APFloat::S_IEEEsingle:
190
+ return " __mulsc3" ;
191
+ case llvm::APFloat::S_IEEEdouble:
192
+ return " __muldc3" ;
193
+ case llvm::APFloat::S_PPCDoubleDouble:
194
+ return " __multc3" ;
195
+ case llvm::APFloat::S_x87DoubleExtended:
196
+ return " __mulxc3" ;
197
+ case llvm::APFloat::S_IEEEquad:
198
+ return " __multc3" ;
199
+ default :
200
+ llvm_unreachable (" unsupported floating point type" );
201
+ }
202
+ }
203
+
204
+ static mlir::Value lowerComplexMul (LoweringPreparePass &pass,
205
+ CIRBaseBuilderTy &builder,
206
+ mlir::Location loc, cir::ComplexMulOp op,
207
+ mlir::Value lhsReal, mlir::Value lhsImag,
208
+ mlir::Value rhsReal, mlir::Value rhsImag) {
209
+ // (a+bi) * (c+di) = (ac-bd) + (ad+bc)i
210
+ mlir::Value resultRealLhs =
211
+ builder.createBinop (loc, lhsReal, cir::BinOpKind::Mul, rhsReal);
212
+ mlir::Value resultRealRhs =
213
+ builder.createBinop (loc, lhsImag, cir::BinOpKind::Mul, rhsImag);
214
+ mlir::Value resultImagLhs =
215
+ builder.createBinop (loc, lhsReal, cir::BinOpKind::Mul, rhsImag);
216
+ mlir::Value resultImagRhs =
217
+ builder.createBinop (loc, lhsImag, cir::BinOpKind::Mul, rhsReal);
218
+ mlir::Value resultReal = builder.createBinop (
219
+ loc, resultRealLhs, cir::BinOpKind::Sub, resultRealRhs);
220
+ mlir::Value resultImag = builder.createBinop (
221
+ loc, resultImagLhs, cir::BinOpKind::Add, resultImagRhs);
222
+ mlir::Value algebraicResult =
223
+ builder.createComplexCreate (loc, resultReal, resultImag);
224
+
225
+ cir::ComplexType complexTy = op.getType ();
226
+ cir::ComplexRangeKind rangeKind = op.getRange ();
227
+ if (mlir::isa<cir::IntType>(complexTy.getElementType ()) ||
228
+ rangeKind == cir::ComplexRangeKind::Basic ||
229
+ rangeKind == cir::ComplexRangeKind::Improved ||
230
+ rangeKind == cir::ComplexRangeKind::Promoted)
231
+ return algebraicResult;
232
+
233
+ assert (!cir::MissingFeatures::fastMathFlags ());
234
+
235
+ // Check whether the real part and the imaginary part of the result are both
236
+ // NaN. If so, emit a library call to compute the multiplication instead.
237
+ // We check a value against NaN by comparing the value against itself.
238
+ mlir::Value resultRealIsNaN = builder.createIsNaN (loc, resultReal);
239
+ mlir::Value resultImagIsNaN = builder.createIsNaN (loc, resultImag);
240
+ mlir::Value resultRealAndImagAreNaN =
241
+ builder.createLogicalAnd (loc, resultRealIsNaN, resultImagIsNaN);
242
+
243
+ return builder
244
+ .create <cir::TernaryOp>(
245
+ loc, resultRealAndImagAreNaN,
246
+ [&](mlir::OpBuilder &, mlir::Location) {
247
+ mlir::Value libCallResult = buildComplexBinOpLibCall (
248
+ pass, builder, &getComplexMulLibCallName, loc, complexTy,
249
+ lhsReal, lhsImag, rhsReal, rhsImag);
250
+ builder.createYield (loc, libCallResult);
251
+ },
252
+ [&](mlir::OpBuilder &, mlir::Location) {
253
+ builder.createYield (loc, algebraicResult);
254
+ })
255
+ .getResult ();
256
+ }
257
+
258
+ void LoweringPreparePass::lowerComplexMulOp (cir::ComplexMulOp op) {
259
+ cir::CIRBaseBuilderTy builder (getContext ());
260
+ builder.setInsertionPointAfter (op);
261
+ mlir::Location loc = op.getLoc ();
262
+ mlir::TypedValue<cir::ComplexType> lhs = op.getLhs ();
263
+ mlir::TypedValue<cir::ComplexType> rhs = op.getRhs ();
264
+ mlir::Value lhsReal = builder.createComplexReal (loc, lhs);
265
+ mlir::Value lhsImag = builder.createComplexImag (loc, lhs);
266
+ mlir::Value rhsReal = builder.createComplexReal (loc, rhs);
267
+ mlir::Value rhsImag = builder.createComplexImag (loc, rhs);
268
+ mlir::Value loweredResult = lowerComplexMul (*this , builder, loc, op, lhsReal,
269
+ lhsImag, rhsReal, rhsImag);
270
+ op.replaceAllUsesWith (loweredResult);
271
+ op.erase ();
272
+ }
273
+
131
274
void LoweringPreparePass::lowerUnaryOp (cir::UnaryOp op) {
132
275
mlir::Type ty = op.getType ();
133
276
if (!mlir::isa<cir::ComplexType>(ty))
@@ -269,18 +412,22 @@ void LoweringPreparePass::runOnOp(mlir::Operation *op) {
269
412
lowerArrayDtor (arrayDtor);
270
413
else if (auto cast = mlir::dyn_cast<cir::CastOp>(op))
271
414
lowerCastOp (cast);
415
+ else if (auto complexMul = mlir::dyn_cast<cir::ComplexMulOp>(op))
416
+ lowerComplexMulOp (complexMul);
272
417
else if (auto unary = mlir::dyn_cast<cir::UnaryOp>(op))
273
418
lowerUnaryOp (unary);
274
419
}
275
420
276
421
void LoweringPreparePass::runOnOperation () {
277
422
mlir::Operation *op = getOperation ();
423
+ if (isa<::mlir::ModuleOp>(op))
424
+ mlirModule = cast<::mlir::ModuleOp>(op);
278
425
279
426
llvm::SmallVector<mlir::Operation *> opsToTransform;
280
427
281
428
op->walk ([&](mlir::Operation *op) {
282
- if (mlir::isa<cir::ArrayCtor, cir::ArrayDtor, cir::CastOp, cir::UnaryOp>(
283
- op))
429
+ if (mlir::isa<cir::ArrayCtor, cir::ArrayDtor, cir::CastOp,
430
+ cir::ComplexMulOp, cir::UnaryOp>( op))
284
431
opsToTransform.push_back (op);
285
432
});
286
433
0 commit comments