Skip to content

Commit 1935f84

Browse files
[mlir][complex] Add complex-range option and select complex division … (#127010)
…algorithm This patch adds the `complex-range` option and two calculation methods for complex number division (algebraic method and Smith's algorithm) to both the `ComplexToLLVM` and `ComplexToStandard` passes, allowing the calculation method to be controlled by the option. See also the discussion in the following discourse post. https://discourse.llvm.org/t/question-and-proposal-regarding-complex-number-division-algorithm-in-the-complex-dialect/83772
1 parent 6fb1d40 commit 1935f84

File tree

18 files changed

+1223
-232
lines changed

18 files changed

+1223
-232
lines changed
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
//===- DivisionConverter.h - Complex division conversion ------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_CONVERSION_COMPLEXCOMMON_DIVISIONCONVERTER_H
10+
#define MLIR_CONVERSION_COMPLEXCOMMON_DIVISIONCONVERTER_H
11+
12+
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
13+
#include "mlir/Dialect/Arith/IR/Arith.h"
14+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15+
16+
namespace mlir {
17+
namespace complex {
18+
/// convert a complex division to the LLVM dialect using algebraic method
19+
void convertDivToLLVMUsingAlgebraic(ConversionPatternRewriter &rewriter,
20+
Location loc, Value lhsRe, Value lhsIm,
21+
Value rhsRe, Value rhsIm,
22+
LLVM::FastmathFlagsAttr fmf,
23+
Value *resultRe, Value *resultIm);
24+
25+
/// convert a complex division to the arith/math dialects using algebraic method
26+
void convertDivToStandardUsingAlgebraic(ConversionPatternRewriter &rewriter,
27+
Location loc, Value lhsRe, Value lhsIm,
28+
Value rhsRe, Value rhsIm,
29+
arith::FastMathFlagsAttr fmf,
30+
Value *resultRe, Value *resultIm);
31+
32+
/// convert a complex division to the LLVM dialect using Smith's method
33+
void convertDivToLLVMUsingRangeReduction(ConversionPatternRewriter &rewriter,
34+
Location loc, Value lhsRe, Value lhsIm,
35+
Value rhsRe, Value rhsIm,
36+
LLVM::FastmathFlagsAttr fmf,
37+
Value *resultRe, Value *resultIm);
38+
39+
/// convert a complex division to the arith/math dialects using Smith's method
40+
void convertDivToStandardUsingRangeReduction(
41+
ConversionPatternRewriter &rewriter, Location loc, Value lhsRe, Value lhsIm,
42+
Value rhsRe, Value rhsIm, arith::FastMathFlagsAttr fmf, Value *resultRe,
43+
Value *resultIm);
44+
45+
} // namespace complex
46+
} // namespace mlir
47+
48+
#endif // MLIR_CONVERSION_COMPLEXCOMMON_DIVISIONCONVERTER_H

mlir/include/mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
#define MLIR_CONVERSION_COMPLEXTOLLVM_COMPLEXTOLLVM_H_
1010

1111
#include "mlir/Conversion/LLVMCommon/StructBuilder.h"
12+
#include "mlir/Dialect/Complex/IR/Complex.h"
13+
#include "mlir/Pass/Pass.h"
1214

1315
namespace mlir {
1416
class DialectRegistry;
@@ -39,8 +41,10 @@ class ComplexStructBuilder : public StructBuilder {
3941
};
4042

4143
/// Populate the given list with patterns that convert from Complex to LLVM.
42-
void populateComplexToLLVMConversionPatterns(const LLVMTypeConverter &converter,
43-
RewritePatternSet &patterns);
44+
void populateComplexToLLVMConversionPatterns(
45+
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
46+
mlir::complex::ComplexRangeFlags complexRange =
47+
mlir::complex::ComplexRangeFlags::basic);
4448

4549
void registerConvertComplexToLLVMInterface(DialectRegistry &registry);
4650

mlir/include/mlir/Conversion/ComplexToStandard/ComplexToStandard.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
#ifndef MLIR_CONVERSION_COMPLEXTOSTANDARD_COMPLEXTOSTANDARD_H_
99
#define MLIR_CONVERSION_COMPLEXTOSTANDARD_COMPLEXTOSTANDARD_H_
1010

11+
#include "mlir/Dialect/Complex/IR/Complex.h"
12+
#include "mlir/Pass/Pass.h"
1113
#include <memory>
1214

1315
namespace mlir {
@@ -18,10 +20,15 @@ class Pass;
1820
#include "mlir/Conversion/Passes.h.inc"
1921

2022
/// Populate the given list with patterns that convert from Complex to Standard.
21-
void populateComplexToStandardConversionPatterns(RewritePatternSet &patterns);
23+
void populateComplexToStandardConversionPatterns(
24+
RewritePatternSet &patterns,
25+
mlir::complex::ComplexRangeFlags complexRange =
26+
mlir::complex::ComplexRangeFlags::improved);
2227

2328
/// Create a pass to convert Complex operations to the Standard dialect.
2429
std::unique_ptr<Pass> createConvertComplexToStandardPass();
30+
std::unique_ptr<Pass>
31+
createConvertComplexToStandardPass(ConvertComplexToStandardOptions options);
2532

2633
} // namespace mlir
2734

mlir/include/mlir/Conversion/Passes.td

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,17 @@ def ConvertBufferizationToMemRef : Pass<"convert-bufferization-to-memref"> {
274274
def ConvertComplexToLLVMPass : Pass<"convert-complex-to-llvm"> {
275275
let summary = "Convert Complex dialect to LLVM dialect";
276276
let dependentDialects = ["LLVM::LLVMDialect"];
277+
278+
let options = [
279+
Option<"complexRange", "complex-range", "::mlir::complex::ComplexRangeFlags",
280+
/*default=*/"::mlir::complex::ComplexRangeFlags::basic",
281+
"Control the intermediate calculation of complex number division",
282+
[{::llvm::cl::values(
283+
clEnumValN(::mlir::complex::ComplexRangeFlags::improved, "improved", "improved"),
284+
clEnumValN(::mlir::complex::ComplexRangeFlags::basic, "basic", "basic (default)"),
285+
clEnumValN(::mlir::complex::ComplexRangeFlags::none, "none", "none")
286+
)}]>,
287+
];
277288
}
278289

279290
//===----------------------------------------------------------------------===//
@@ -308,6 +319,17 @@ def ConvertComplexToStandard : Pass<"convert-complex-to-standard"> {
308319
let summary = "Convert Complex dialect to standard dialect";
309320
let constructor = "mlir::createConvertComplexToStandardPass()";
310321
let dependentDialects = ["math::MathDialect"];
322+
323+
let options = [
324+
Option<"complexRange", "complex-range", "::mlir::complex::ComplexRangeFlags",
325+
/*default=*/"::mlir::complex::ComplexRangeFlags::improved",
326+
"Control the intermediate calculation of complex number division",
327+
[{::llvm::cl::values(
328+
clEnumValN(::mlir::complex::ComplexRangeFlags::improved, "improved", "improved (default)"),
329+
clEnumValN(::mlir::complex::ComplexRangeFlags::basic, "basic", "basic"),
330+
clEnumValN(::mlir::complex::ComplexRangeFlags::none, "none", "none")
331+
)}]>,
332+
];
311333
}
312334

313335
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Complex/IR/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ add_mlir_dialect(ComplexOps complex)
22
add_mlir_doc(ComplexOps ComplexOps Dialects/ -gen-dialect-doc -dialect=complex)
33

44
set(LLVM_TARGET_DEFINITIONS ComplexAttributes.td)
5+
mlir_tablegen(ComplexEnums.h.inc -gen-enum-decls)
6+
mlir_tablegen(ComplexEnums.cpp.inc -gen-enum-defs)
57
mlir_tablegen(ComplexAttributes.h.inc -gen-attrdef-decls)
68
mlir_tablegen(ComplexAttributes.cpp.inc -gen-attrdef-defs)
79
add_public_tablegen_target(MLIRComplexAttributesIncGen)

mlir/include/mlir/Dialect/Complex/IR/Complex.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,12 @@
2222

2323
#include "mlir/Dialect/Complex/IR/ComplexOpsDialect.h.inc"
2424

25+
//===----------------------------------------------------------------------===//
26+
// Complex Dialect Enums
27+
//===----------------------------------------------------------------------===//
28+
29+
#include "mlir/Dialect/Complex/IR/ComplexEnums.h.inc"
30+
2531
//===----------------------------------------------------------------------===//
2632
// Complex Dialect Operations
2733
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Complex/IR/ComplexBase.td

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#ifndef COMPLEX_BASE
1010
#define COMPLEX_BASE
1111

12+
include "mlir/IR/EnumAttr.td"
1213
include "mlir/IR/OpBase.td"
1314

1415
def Complex_Dialect : Dialect {
@@ -24,4 +25,19 @@ def Complex_Dialect : Dialect {
2425
let useDefaultAttributePrinterParser = 1;
2526
}
2627

28+
//===----------------------------------------------------------------------===//
29+
// Complex_ComplexRangeFlags
30+
//===----------------------------------------------------------------------===//
31+
32+
def Complex_CRF_improved : I32BitEnumAttrCaseBit<"improved", 0>;
33+
def Complex_CRF_basic : I32BitEnumAttrCaseBit<"basic", 1>;
34+
def Complex_CRF_none : I32BitEnumAttrCaseBit<"none", 2>;
35+
36+
def Complex_ComplexRangeFlags : I32BitEnumAttr<
37+
"ComplexRangeFlags",
38+
"Complex range flags",
39+
[Complex_CRF_improved, Complex_CRF_basic, Complex_CRF_none]> {
40+
let cppNamespace = "::mlir::complex";
41+
}
42+
2743
#endif // COMPLEX_BASE

mlir/lib/Conversion/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ add_subdirectory(ArmSMEToSCF)
1111
add_subdirectory(ArmSMEToLLVM)
1212
add_subdirectory(AsyncToLLVM)
1313
add_subdirectory(BufferizationToMemRef)
14+
add_subdirectory(ComplexCommon)
1415
add_subdirectory(ComplexToLibm)
1516
add_subdirectory(ComplexToLLVM)
1617
add_subdirectory(ComplexToSPIRV)
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
add_mlir_conversion_library(MLIRComplexDivisionConversion
2+
DivisionConverter.cpp
3+
4+
LINK_COMPONENTS
5+
Core
6+
7+
LINK_LIBS PUBLIC
8+
MLIRArithDialect
9+
MLIRComplexDialect
10+
MLIRLLVMDialect
11+
MLIRMathDialect
12+
)

0 commit comments

Comments
 (0)