Skip to content

Conversation

@s-watanabe314
Copy link
Contributor

…algorithm

This patch adds the complex-range option and two calculation methods for complex number division (algebraic method and Smith's algorithm) to both the ComplexToLLVM and ComplexToStandard passes, allowing the calculation method to be controlled by the option.

See also the discussion in the following discourse post.

https://discourse.llvm.org/t/question-and-proposal-regarding-complex-number-division-algorithm-in-the-complex-dialect/83772

…algorithm

This patch adds the `complex-range` option and two calculation methods
for complex number division (algebraic method and Smith's algorithm)
to both the `ComplexToLLVM` and `ComplexToStandard` passes, allowing
the calculation method to be controlled by the option.

See also the discussion in the following discourse post.

https://discourse.llvm.org/t/question-and-proposal-regarding-complex-number-division-algorithm-in-the-complex-dialect/83772
@llvmbot llvmbot added mlir mlir:complex MLIR complex dialect labels Feb 13, 2025
@llvmbot
Copy link
Member

llvmbot commented Feb 13, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-complex

Author: None (s-watanabe314)

Changes

…algorithm

This patch adds the complex-range option and two calculation methods for complex number division (algebraic method and Smith's algorithm) to both the ComplexToLLVM and ComplexToStandard passes, allowing the calculation method to be controlled by the option.

See also the discussion in the following discourse post.

https://discourse.llvm.org/t/question-and-proposal-regarding-complex-number-division-algorithm-in-the-complex-dialect/83772


Patch is 110.52 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/127010.diff

18 Files Affected:

  • (added) mlir/include/mlir/Conversion/ComplexCommon/DivisionConverter.h (+48)
  • (modified) mlir/include/mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h (+6-2)
  • (modified) mlir/include/mlir/Conversion/ComplexToStandard/ComplexToStandard.h (+8-1)
  • (modified) mlir/include/mlir/Conversion/Passes.td (+22)
  • (modified) mlir/include/mlir/Dialect/Complex/IR/CMakeLists.txt (+2)
  • (modified) mlir/include/mlir/Dialect/Complex/IR/Complex.h (+6)
  • (modified) mlir/include/mlir/Dialect/Complex/IR/ComplexBase.td (+16)
  • (modified) mlir/lib/Conversion/CMakeLists.txt (+1)
  • (added) mlir/lib/Conversion/ComplexCommon/CMakeLists.txt (+12)
  • (added) mlir/lib/Conversion/ComplexCommon/DivisionConverter.cpp (+456)
  • (modified) mlir/lib/Conversion/ComplexToLLVM/CMakeLists.txt (+1)
  • (modified) mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp (+25-19)
  • (modified) mlir/lib/Conversion/ComplexToStandard/CMakeLists.txt (+1)
  • (modified) mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp (+36-207)
  • (added) mlir/test/Conversion/ComplexToLLVM/complex-range-option.mlir (+303)
  • (modified) mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir (+2-2)
  • (modified) mlir/test/Conversion/ComplexToLLVM/full-conversion.mlir (+1-1)
  • (added) mlir/test/Conversion/ComplexToStandard/complex-range-option.mlir (+277)
diff --git a/mlir/include/mlir/Conversion/ComplexCommon/DivisionConverter.h b/mlir/include/mlir/Conversion/ComplexCommon/DivisionConverter.h
new file mode 100644
index 0000000000000..df97dc2c4eb7d
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ComplexCommon/DivisionConverter.h
@@ -0,0 +1,48 @@
+//===- DivisionConverter.h - Complex division conversion ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_COMPLEXCOMMON_DIVISIONCONVERTER_H
+#define MLIR_CONVERSION_COMPLEXCOMMON_DIVISIONCONVERTER_H
+
+#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+
+namespace mlir {
+namespace complex {
+/// convert a complex division to the LLVM dialect using algebraic method
+void convertDivToLLVMUsingAlgebraic(ConversionPatternRewriter &rewriter,
+                                    Location loc, Value lhsRe, Value lhsIm,
+                                    Value rhsRe, Value rhsIm,
+                                    LLVM::FastmathFlagsAttr fmf,
+                                    Value *resultRe, Value *resultIm);
+
+/// convert a complex division to the arith/math dialects using algebraic method
+void convertDivToStandardUsingAlgebraic(ConversionPatternRewriter &rewriter,
+                                        Location loc, Value lhsRe, Value lhsIm,
+                                        Value rhsRe, Value rhsIm,
+                                        arith::FastMathFlagsAttr fmf,
+                                        Value *resultRe, Value *resultIm);
+
+/// convert a complex division to the LLVM dialect using Smith's method
+void convertDivToLLVMUsingRangeReduction(ConversionPatternRewriter &rewriter,
+                                         Location loc, Value lhsRe, Value lhsIm,
+                                         Value rhsRe, Value rhsIm,
+                                         LLVM::FastmathFlagsAttr fmf,
+                                         Value *resultRe, Value *resultIm);
+
+/// convert a complex division to the arith/math dialects using Smith's method
+void convertDivToStandardUsingRangeReduction(
+    ConversionPatternRewriter &rewriter, Location loc, Value lhsRe, Value lhsIm,
+    Value rhsRe, Value rhsIm, arith::FastMathFlagsAttr fmf, Value *resultRe,
+    Value *resultIm);
+
+} // namespace complex
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_COMPLEXCOMMON_DIVISIONCONVERTER_H
diff --git a/mlir/include/mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h b/mlir/include/mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h
index 8266442cf5db8..1db75563fe304 100644
--- a/mlir/include/mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h
+++ b/mlir/include/mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h
@@ -9,6 +9,8 @@
 #define MLIR_CONVERSION_COMPLEXTOLLVM_COMPLEXTOLLVM_H_
 
 #include "mlir/Conversion/LLVMCommon/StructBuilder.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Pass/Pass.h"
 
 namespace mlir {
 class DialectRegistry;
@@ -39,8 +41,10 @@ class ComplexStructBuilder : public StructBuilder {
 };
 
 /// Populate the given list with patterns that convert from Complex to LLVM.
-void populateComplexToLLVMConversionPatterns(const LLVMTypeConverter &converter,
-                                             RewritePatternSet &patterns);
+void populateComplexToLLVMConversionPatterns(
+    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
+    mlir::complex::ComplexRangeFlags complexRange =
+        mlir::complex::ComplexRangeFlags::basic);
 
 void registerConvertComplexToLLVMInterface(DialectRegistry &registry);
 
diff --git a/mlir/include/mlir/Conversion/ComplexToStandard/ComplexToStandard.h b/mlir/include/mlir/Conversion/ComplexToStandard/ComplexToStandard.h
index 39c4a1ae54617..30b86cac9cd4e 100644
--- a/mlir/include/mlir/Conversion/ComplexToStandard/ComplexToStandard.h
+++ b/mlir/include/mlir/Conversion/ComplexToStandard/ComplexToStandard.h
@@ -8,6 +8,8 @@
 #ifndef MLIR_CONVERSION_COMPLEXTOSTANDARD_COMPLEXTOSTANDARD_H_
 #define MLIR_CONVERSION_COMPLEXTOSTANDARD_COMPLEXTOSTANDARD_H_
 
+#include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Pass/Pass.h"
 #include <memory>
 
 namespace mlir {
@@ -18,10 +20,15 @@ class Pass;
 #include "mlir/Conversion/Passes.h.inc"
 
 /// Populate the given list with patterns that convert from Complex to Standard.
-void populateComplexToStandardConversionPatterns(RewritePatternSet &patterns);
+void populateComplexToStandardConversionPatterns(
+    RewritePatternSet &patterns,
+    mlir::complex::ComplexRangeFlags complexRange =
+        mlir::complex::ComplexRangeFlags::improved);
 
 /// Create a pass to convert Complex operations to the Standard dialect.
 std::unique_ptr<Pass> createConvertComplexToStandardPass();
+std::unique_ptr<Pass>
+createConvertComplexToStandardPass(ConvertComplexToStandardOptions options);
 
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index ff79a1226c047..5203838a6eb35 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -274,6 +274,17 @@ def ConvertBufferizationToMemRef : Pass<"convert-bufferization-to-memref"> {
 def ConvertComplexToLLVMPass : Pass<"convert-complex-to-llvm"> {
   let summary = "Convert Complex dialect to LLVM dialect";
   let dependentDialects = ["LLVM::LLVMDialect"];
+
+  let options = [
+    Option<"complexRange", "complex-range", "::mlir::complex::ComplexRangeFlags",
+      /*default=*/"::mlir::complex::ComplexRangeFlags::basic",
+      "Control the intermediate calculation of complex number division",
+      [{::llvm::cl::values(
+        clEnumValN(::mlir::complex::ComplexRangeFlags::improved, "improved", "improved"),
+        clEnumValN(::mlir::complex::ComplexRangeFlags::basic, "basic", "basic (default)"),
+        clEnumValN(::mlir::complex::ComplexRangeFlags::none, "none", "none")
+      )}]>,
+  ];
 }
 
 //===----------------------------------------------------------------------===//
@@ -308,6 +319,17 @@ def ConvertComplexToStandard : Pass<"convert-complex-to-standard"> {
   let summary = "Convert Complex dialect to standard dialect";
   let constructor = "mlir::createConvertComplexToStandardPass()";
   let dependentDialects = ["math::MathDialect"];
+
+  let options = [
+    Option<"complexRange", "complex-range", "::mlir::complex::ComplexRangeFlags",
+      /*default=*/"::mlir::complex::ComplexRangeFlags::improved",
+      "Control the intermediate calculation of complex number division",
+      [{::llvm::cl::values(
+        clEnumValN(::mlir::complex::ComplexRangeFlags::improved, "improved", "improved (default)"),
+        clEnumValN(::mlir::complex::ComplexRangeFlags::basic, "basic", "basic"),
+        clEnumValN(::mlir::complex::ComplexRangeFlags::none, "none", "none")
+      )}]>,
+  ];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Complex/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Complex/IR/CMakeLists.txt
index f41888d01a2fd..837664e25b3c2 100644
--- a/mlir/include/mlir/Dialect/Complex/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Complex/IR/CMakeLists.txt
@@ -2,6 +2,8 @@ add_mlir_dialect(ComplexOps complex)
 add_mlir_doc(ComplexOps ComplexOps Dialects/ -gen-dialect-doc  -dialect=complex)
 
 set(LLVM_TARGET_DEFINITIONS ComplexAttributes.td)
+mlir_tablegen(ComplexEnums.h.inc -gen-enum-decls)
+mlir_tablegen(ComplexEnums.cpp.inc -gen-enum-defs)
 mlir_tablegen(ComplexAttributes.h.inc -gen-attrdef-decls)
 mlir_tablegen(ComplexAttributes.cpp.inc -gen-attrdef-defs)
 add_public_tablegen_target(MLIRComplexAttributesIncGen)
diff --git a/mlir/include/mlir/Dialect/Complex/IR/Complex.h b/mlir/include/mlir/Dialect/Complex/IR/Complex.h
index fb024fa2e951e..be7e50d656385 100644
--- a/mlir/include/mlir/Dialect/Complex/IR/Complex.h
+++ b/mlir/include/mlir/Dialect/Complex/IR/Complex.h
@@ -22,6 +22,12 @@
 
 #include "mlir/Dialect/Complex/IR/ComplexOpsDialect.h.inc"
 
+//===----------------------------------------------------------------------===//
+// Complex Dialect Enums
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Complex/IR/ComplexEnums.h.inc"
+
 //===----------------------------------------------------------------------===//
 // Complex Dialect Operations
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexBase.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexBase.td
index 31135fc8c8ce7..c8af498f44829 100644
--- a/mlir/include/mlir/Dialect/Complex/IR/ComplexBase.td
+++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexBase.td
@@ -9,6 +9,7 @@
 #ifndef COMPLEX_BASE
 #define COMPLEX_BASE
 
+include "mlir/IR/EnumAttr.td"
 include "mlir/IR/OpBase.td"
 
 def Complex_Dialect : Dialect {
@@ -24,4 +25,19 @@ def Complex_Dialect : Dialect {
   let useDefaultAttributePrinterParser = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// Complex_ComplexRangeFlags
+//===----------------------------------------------------------------------===//
+
+def Complex_CRF_improved  : I32BitEnumAttrCaseBit<"improved", 0>;
+def Complex_CRF_basic : I32BitEnumAttrCaseBit<"basic", 1>;
+def Complex_CRF_none  : I32BitEnumAttrCaseBit<"none", 2>;
+
+def Complex_ComplexRangeFlags : I32BitEnumAttr<
+    "ComplexRangeFlags",
+    "Complex range flags",
+    [Complex_CRF_improved, Complex_CRF_basic, Complex_CRF_none]> {
+  let cppNamespace = "::mlir::complex";
+}
+
 #endif // COMPLEX_BASE
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 0bd08ec6333e6..fa904a33ebf96 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -11,6 +11,7 @@ add_subdirectory(ArmSMEToSCF)
 add_subdirectory(ArmSMEToLLVM)
 add_subdirectory(AsyncToLLVM)
 add_subdirectory(BufferizationToMemRef)
+add_subdirectory(ComplexCommon)
 add_subdirectory(ComplexToLibm)
 add_subdirectory(ComplexToLLVM)
 add_subdirectory(ComplexToSPIRV)
diff --git a/mlir/lib/Conversion/ComplexCommon/CMakeLists.txt b/mlir/lib/Conversion/ComplexCommon/CMakeLists.txt
new file mode 100644
index 0000000000000..2560a4a5631f4
--- /dev/null
+++ b/mlir/lib/Conversion/ComplexCommon/CMakeLists.txt
@@ -0,0 +1,12 @@
+add_mlir_conversion_library(MLIRComplexDivisionConversion
+  DivisionConverter.cpp
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRArithDialect
+  MLIRComplexDialect
+  MLIRLLVMDialect
+  MLIRMathDialect
+  )
diff --git a/mlir/lib/Conversion/ComplexCommon/DivisionConverter.cpp b/mlir/lib/Conversion/ComplexCommon/DivisionConverter.cpp
new file mode 100644
index 0000000000000..cce9cc77c3a4c
--- /dev/null
+++ b/mlir/lib/Conversion/ComplexCommon/DivisionConverter.cpp
@@ -0,0 +1,456 @@
+//===- DivisionConverter.cpp - Complex division conversion ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements functions for two different complex number division
+// algorithms, the `algebraic formula` and `Smith's range reduction method`.
+// These are used in two conversions: `ComplexToLLVM` and `ComplexToStandard`.
+// When modifying the algorithms, both `ToLLVM` and `ToStandard` must be
+// changed.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ComplexCommon/DivisionConverter.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+
+using namespace mlir;
+
+void mlir::complex::convertDivToLLVMUsingAlgebraic(
+    ConversionPatternRewriter &rewriter, Location loc, Value lhsRe, Value lhsIm,
+    Value rhsRe, Value rhsIm, LLVM::FastmathFlagsAttr fmf, Value *resultRe,
+    Value *resultIm) {
+  Value rhsSqNorm = rewriter.create<LLVM::FAddOp>(
+      loc, rewriter.create<LLVM::FMulOp>(loc, rhsRe, rhsRe, fmf),
+      rewriter.create<LLVM::FMulOp>(loc, rhsIm, rhsIm, fmf), fmf);
+
+  Value realNumerator = rewriter.create<LLVM::FAddOp>(
+      loc, rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsRe, fmf),
+      rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsIm, fmf), fmf);
+
+  Value imagNumerator = rewriter.create<LLVM::FSubOp>(
+      loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf),
+      rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf), fmf);
+
+  *resultRe = rewriter.create<LLVM::FDivOp>(loc, realNumerator, rhsSqNorm, fmf);
+  *resultIm = rewriter.create<LLVM::FDivOp>(loc, imagNumerator, rhsSqNorm, fmf);
+}
+
+void mlir::complex::convertDivToStandardUsingAlgebraic(
+    ConversionPatternRewriter &rewriter, Location loc, Value lhsRe, Value lhsIm,
+    Value rhsRe, Value rhsIm, arith::FastMathFlagsAttr fmf, Value *resultRe,
+    Value *resultIm) {
+  Value rhsSqNorm = rewriter.create<arith::AddFOp>(
+      loc, rewriter.create<arith::MulFOp>(loc, rhsRe, rhsRe, fmf),
+      rewriter.create<arith::MulFOp>(loc, rhsIm, rhsIm, fmf), fmf);
+
+  Value realNumerator = rewriter.create<arith::AddFOp>(
+      loc, rewriter.create<arith::MulFOp>(loc, lhsRe, rhsRe, fmf),
+      rewriter.create<arith::MulFOp>(loc, lhsIm, rhsIm, fmf), fmf);
+  Value imagNumerator = rewriter.create<arith::SubFOp>(
+      loc, rewriter.create<arith::MulFOp>(loc, lhsIm, rhsRe, fmf),
+      rewriter.create<arith::MulFOp>(loc, lhsRe, rhsIm, fmf), fmf);
+
+  *resultRe =
+      rewriter.create<arith::DivFOp>(loc, realNumerator, rhsSqNorm, fmf);
+  *resultIm =
+      rewriter.create<arith::DivFOp>(loc, imagNumerator, rhsSqNorm, fmf);
+};
+
+// Smith's algorithm to divide complex numbers. It is just a bit smarter
+// way to compute the following algebraic formula:
+//  (lhsRe + lhsIm * i) / (rhsRe + rhsIm * i)
+//    = (lhsRe + lhsIm * i) (rhsRe - rhsIm * i) /
+//          ((rhsRe + rhsIm * i)(rhsRe - rhsIm * i))
+//    = ((lhsRe * rhsRe + lhsIm * rhsIm) +
+//          (lhsIm * rhsRe - lhsRe * rhsIm) * i) / ||rhs||^2
+//
+// Depending on whether |rhsRe| < |rhsIm| we compute either
+//   rhsRealImagRatio = rhsRe / rhsIm
+//   rhsRealImagDenom = rhsIm + rhsRe * rhsRealImagRatio
+//   resultRe = (lhsRe * rhsRealImagRatio + lhsIm) /
+//                  rhsRealImagDenom
+//   resultIm = (lhsIm * rhsRealImagRatio - lhsRe) /
+//                  rhsRealImagDenom
+//
+// or
+//
+//   rhsImagRealRatio = rhsIm / rhsRe
+//   rhsImagRealDenom = rhsRe + rhsIm * rhsImagRealRatio
+//   resultRe = (lhsRe + lhsIm * rhsImagRealRatio) /
+//                  rhsImagRealDenom
+//   resultIm = (lhsIm - lhsRe * rhsImagRealRatio) /
+//                  rhsImagRealDenom
+//
+// See https://dl.acm.org/citation.cfm?id=368661 for more details.
+
+void mlir::complex::convertDivToLLVMUsingRangeReduction(
+    ConversionPatternRewriter &rewriter, Location loc, Value lhsRe, Value lhsIm,
+    Value rhsRe, Value rhsIm, LLVM::FastmathFlagsAttr fmf, Value *resultRe,
+    Value *resultIm) {
+  auto elementType = cast<FloatType>(rhsRe.getType());
+
+  Value rhsRealImagRatio =
+      rewriter.create<LLVM::FDivOp>(loc, rhsRe, rhsIm, fmf);
+  Value rhsRealImagDenom = rewriter.create<LLVM::FAddOp>(
+      loc, rhsIm,
+      rewriter.create<LLVM::FMulOp>(loc, rhsRealImagRatio, rhsRe, fmf), fmf);
+  Value realNumerator1 = rewriter.create<LLVM::FAddOp>(
+      loc, rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsRealImagRatio, fmf),
+      lhsIm, fmf);
+  Value resultReal1 =
+      rewriter.create<LLVM::FDivOp>(loc, realNumerator1, rhsRealImagDenom, fmf);
+  Value imagNumerator1 = rewriter.create<LLVM::FSubOp>(
+      loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRealImagRatio, fmf),
+      lhsRe, fmf);
+  Value resultImag1 =
+      rewriter.create<LLVM::FDivOp>(loc, imagNumerator1, rhsRealImagDenom, fmf);
+
+  Value rhsImagRealRatio =
+      rewriter.create<LLVM::FDivOp>(loc, rhsIm, rhsRe, fmf);
+  Value rhsImagRealDenom = rewriter.create<LLVM::FAddOp>(
+      loc, rhsRe,
+      rewriter.create<LLVM::FMulOp>(loc, rhsImagRealRatio, rhsIm, fmf), fmf);
+  Value realNumerator2 = rewriter.create<LLVM::FAddOp>(
+      loc, lhsRe,
+      rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsImagRealRatio, fmf), fmf);
+  Value resultReal2 =
+      rewriter.create<LLVM::FDivOp>(loc, realNumerator2, rhsImagRealDenom, fmf);
+  Value imagNumerator2 = rewriter.create<LLVM::FSubOp>(
+      loc, lhsIm,
+      rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsImagRealRatio, fmf), fmf);
+  Value resultImag2 =
+      rewriter.create<LLVM::FDivOp>(loc, imagNumerator2, rhsImagRealDenom, fmf);
+
+  // Consider corner cases.
+  // Case 1. Zero denominator, numerator contains at most one NaN value.
+  Value zero = rewriter.create<LLVM::ConstantOp>(
+      loc, elementType, rewriter.getZeroAttr(elementType));
+  Value rhsRealAbs = rewriter.create<LLVM::FAbsOp>(loc, rhsRe, fmf);
+  Value rhsRealIsZero = rewriter.create<LLVM::FCmpOp>(
+      loc, LLVM::FCmpPredicate::oeq, rhsRealAbs, zero);
+  Value rhsImagAbs = rewriter.create<LLVM::FAbsOp>(loc, rhsIm, fmf);
+  Value rhsImagIsZero = rewriter.create<LLVM::FCmpOp>(
+      loc, LLVM::FCmpPredicate::oeq, rhsImagAbs, zero);
+  Value lhsRealIsNotNaN =
+      rewriter.create<LLVM::FCmpOp>(loc, LLVM::FCmpPredicate::ord, lhsRe, zero);
+  Value lhsImagIsNotNaN =
+      rewriter.create<LLVM::FCmpOp>(loc, LLVM::FCmpPredicate::ord, lhsIm, zero);
+  Value lhsContainsNotNaNValue =
+      rewriter.create<LLVM::OrOp>(loc, lhsRealIsNotNaN, lhsImagIsNotNaN);
+  Value resultIsInfinity = rewriter.create<LLVM::AndOp>(
+      loc, lhsContainsNotNaNValue,
+      rewriter.create<LLVM::AndOp>(loc, rhsRealIsZero, rhsImagIsZero));
+  Value inf = rewriter.create<LLVM::ConstantOp>(
+      loc, elementType,
+      rewriter.getFloatAttr(elementType,
+                            APFloat::getInf(elementType.getFloatSemantics())));
+  Value infWithSignOfrhsReal =
+      rewriter.create<LLVM::CopySignOp>(loc, inf, rhsRe);
+  Value infinityResultReal =
+      rewriter.create<LLVM::FMulOp>(loc, infWithSignOfrhsReal, lhsRe, fmf);
+  Value infinityResultImag =
+      rewriter.create<LLVM::FMulOp>(loc, infWithSignOfrhsReal, lhsIm, fmf);
+
+  // Case 2. Infinite numerator, finite denominator.
+  Value rhsRealFinite = rewriter.create<LLVM::FCmpOp>(
+      loc, LLVM::FCmpPredicate::one, rhsRealAbs, inf);
+  Value rhsImagFinite = rewriter.create<LLVM::FCmpOp>(
+      loc, LLVM::FCmpPredicate::one, rhsImagAbs, inf);
+  Value rhsFinite =
+      rewriter.create<LLVM::AndOp>(loc, rhsRealFinite, rhsImagFinite);
+  Value lhsRealAbs = rewriter.create<LLVM::FAbsOp>(loc, lhsRe, fmf);
+  Value lhsRealInfinite = rewriter.create<LLVM::FCmpOp>(
+      loc, LLVM::FCmpPredicate::oeq, lhsRealAbs, inf);
+  Value lhsImagAbs = rewriter.create<LLVM::FAbsOp>(loc, lhsIm, fmf);
+  Value lhsImagInfinite = rewriter.create<LLVM::FCmpOp>(
+      loc, LLVM::FCmpPredicate::oeq, lhsImagAbs, inf);
+  Value lhsInfinite =
+      rewriter.create<LLVM::OrOp>(loc, lhsRealInfinite, lhsImagInfinite);
+  Value infNumFiniteDenom =
+      rewriter.create<LLVM::AndOp>(loc, lhsInfinite, rhsFinite);
+  Value one = rewriter.create<LLVM::ConstantOp>(
+      loc, elementType, rewriter.getFloatAttr(elementType, 1));
+  Value lhsRealIsInfWithSign = rewriter.create<LLVM::CopySignOp>(
+      loc, rewriter.create<LLVM::SelectOp>(loc, lhsRealInfinite, one, zero),
+      lhsRe);
+  Value lhsImagIsInfWithSign = rewriter.create<LLVM::CopySignOp>(
+      loc, rewriter.create<LLVM::SelectOp>(loc, lhsImagInfinite, one, zero),
+      lhsIm);
+  Value lhsRealIsInfWithSignTimesrhsReal =
+      rewriter.create<LLVM::FMulOp>(loc, lhsRealIsInfWithSign, rhsRe, fmf);
+  Value lhsImagIsInfWithSignTimesrhsImag =
+      rewriter.create<LLVM::FMulOp>(loc, lhsImagIsInfWithSign, rhsIm, fmf);
+  Value resultReal3 = rewriter.create<LLVM::FMulOp>(
+      loc, inf,
+      rewriter.create<LLVM::FAddOp>(loc, lhsRealIsInfWithSignTimesrhsReal,
+                                    lhsImagIsInfWithSignTimesrhsImag, fmf),
+      fmf);
+  Value lhsRealIsInfWithSignTimesrh...
[truncated]

Copy link
Member

@akuegel akuegel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me but please wait also for others to take a look.

@pifon2a pifon2a merged commit 1935f84 into llvm:main Feb 13, 2025
11 checks passed
joaosaffran pushed a commit to joaosaffran/llvm-project that referenced this pull request Feb 14, 2025
llvm#127010)

…algorithm

This patch adds the `complex-range` option and two calculation methods
for complex number division (algebraic method and Smith's algorithm) to
both the `ComplexToLLVM` and `ComplexToStandard` passes, allowing the
calculation method to be controlled by the option.

See also the discussion in the following discourse post.


https://discourse.llvm.org/t/question-and-proposal-regarding-complex-number-division-algorithm-in-the-complex-dialect/83772
sivan-shani pushed a commit to sivan-shani/llvm-project that referenced this pull request Feb 24, 2025
llvm#127010)

…algorithm

This patch adds the `complex-range` option and two calculation methods
for complex number division (algebraic method and Smith's algorithm) to
both the `ComplexToLLVM` and `ComplexToStandard` passes, allowing the
calculation method to be controlled by the option.

See also the discussion in the following discourse post.


https://discourse.llvm.org/t/question-and-proposal-regarding-complex-number-division-algorithm-in-the-complex-dialect/83772
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

mlir:complex MLIR complex dialect mlir

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants