Skip to content

Commit ce6d35e

Browse files
pifon2a and Google-ML-Automation
authored and committed
[XLA:GPU][Emitters] Remove the complex.expm1 approximation.
It was upstreamed in llvm/llvm-project#115082 (review). Now we can use the complex-to-standard pass. Reverts d2e313c. PiperOrigin-RevId: 698191660
1 parent e9947dd commit ce6d35e

File tree

3 files changed

+3
-127
lines changed

3 files changed

+3
-127
lines changed

xla/service/gpu/fusions/transforms/lower_tensors.cc

Lines changed: 3 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License.
1414
==============================================================================*/
15-
#include <array>
1615
#include <cassert>
1716
#include <cstdint>
1817
#include <memory>
@@ -31,13 +30,11 @@ limitations under the License.
3130
#include "llvm/Support/LogicalResult.h"
3231
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
3332
#include "mlir/Dialect/Arith/IR/Arith.h"
34-
#include "mlir/Dialect/Complex/IR/Complex.h"
3533
#include "mlir/Dialect/Func/IR/FuncOps.h"
3634
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
3735
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
3836
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
3937
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
40-
#include "mlir/Dialect/Math/IR/Math.h"
4138
#include "mlir/Dialect/SCF/IR/SCF.h"
4239
#include "mlir/Dialect/Tensor/IR/Tensor.h"
4340
#include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -74,26 +71,19 @@ namespace {
7471
#define GEN_PASS_DEF_LOWERTENSORSPASS
7572
#include "xla/service/gpu/fusions/transforms/passes.h.inc"
7673

77-
using llvm::APFloat;
78-
using llvm::ArrayRef;
7974
using mlir::failure;
80-
using mlir::ImplicitLocOpBuilder;
8175
using mlir::Location;
8276
using mlir::LogicalResult;
8377
using mlir::MLIRContext;
8478
using mlir::OpBuilder;
8579
using mlir::Operation;
86-
using mlir::OpRewritePattern;
8780
using mlir::success;
8881
using mlir::Type;
8982
using mlir::TypedValue;
9083
using mlir::TypeRange;
9184
using mlir::Value;
9285
using mlir::ValueRange;
9386

94-
namespace ma = ::mlir::arith;
95-
namespace mc = ::mlir::complex;
96-
namespace mm = ::mlir::math;
9787
namespace arith = ::mlir::arith;
9888
namespace scf = ::mlir::scf;
9989
namespace ml = ::mlir::LLVM;
@@ -249,7 +239,7 @@ struct RewriteTensorExtract : mlir::OpRewritePattern<mlir::tensor::ExtractOp> {
249239
load, b.create<mlir::arith::ConstantIntOp>(4, load.getType()));
250240
load = b.create<mlir::arith::TruncIOp>(
251241
op.getType(),
252-
b.create<ma::SelectOp>(is_low_nibble, load, high_value));
242+
b.create<mlir::arith::SelectOp>(is_low_nibble, load, high_value));
253243
}
254244

255245
rewriter.replaceOpWithNewOp<mlir::UnrealizedConversionCastOp>(
@@ -378,7 +368,7 @@ struct RewriteTensorInsert : mlir::OpRewritePattern<mlir::tensor::InsertOp> {
378368
body_builder.create<mlir::arith::ShLIOp>(
379369
scalar_value,
380370
body_builder.create<mlir::arith::ConstantIntOp>(4, ty)));
381-
Value new_value = body_builder.create<ma::SelectOp>(
371+
Value new_value = body_builder.create<mlir::arith::SelectOp>(
382372
is_low_nibble, low_updated, high_updated);
383373
body_builder.create<mlir::scf::YieldOp>(new_value);
384374
Value casted_result = b.create<mlir::UnrealizedConversionCastOp>(
@@ -1052,106 +1042,6 @@ class RewriteAtomicRMW : public mlir::OpRewritePattern<AtomicRMWOp> {
10521042
std::string gpu_arch_;
10531043
};
10541044

1055-
template <typename FType>
1056-
Value EvaluatePolynomial(ImplicitLocOpBuilder& b, Value arg,
1057-
ArrayRef<FType> coefficients) {
1058-
auto arg_type = mlir::cast<mlir::FloatType>(arg.getType());
1059-
Value poly =
1060-
b.create<ma::ConstantOp>(b.getFloatAttr(arg_type, coefficients[0]));
1061-
for (int i = 1; i < coefficients.size(); ++i) {
1062-
poly = b.create<mm::FmaOp>(
1063-
poly, arg,
1064-
b.create<ma::ConstantOp>(b.getFloatAttr(arg_type, coefficients[i])));
1065-
}
1066-
return poly;
1067-
};
1068-
1069-
struct RewriterExpm1Op : public OpRewritePattern<mc::Expm1Op> {
1070-
using OpRewritePattern<mc::Expm1Op>::OpRewritePattern;
1071-
1072-
// e^(a+bi)-1 = (e^a*cos(b)-1)+e^a*sin(b)i
1073-
// [handle inaccuracies when a and/or b are small]
1074-
// = ((e^a - 1) * cos(b) + cos(b) - 1) + e^a*sin(b)i
1075-
// = (expm1(a) * cos(b) + cosm1(b)) + e^a*sin(b)i
1076-
mlir::LogicalResult matchAndRewrite(
1077-
mc::Expm1Op op, mlir::PatternRewriter& rewriter) const override {
1078-
auto type = op.getType();
1079-
auto element_type = mlir::cast<mlir::FloatType>(type.getElementType());
1080-
1081-
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
1082-
1083-
Value real = b.create<mc::ReOp>(op.getComplex());
1084-
Value imag = b.create<mc::ImOp>(op.getComplex());
1085-
1086-
Value zero = b.create<ma::ConstantOp>(b.getFloatAttr(element_type, 0.0));
1087-
Value one = b.create<ma::ConstantOp>(b.getFloatAttr(element_type, 1.0));
1088-
1089-
Value expm1_real = b.create<mm::ExpM1Op>(real);
1090-
Value exp_real = b.create<ma::AddFOp>(expm1_real, one);
1091-
1092-
Value sin_imag = b.create<mm::SinOp>(imag);
1093-
Value cosm1_imag = EmitCosm1(imag, b);
1094-
Value cos_imag = b.create<ma::AddFOp>(cosm1_imag, one);
1095-
1096-
Value real_result = b.create<ma::AddFOp>(
1097-
b.create<ma::MulFOp>(expm1_real, cos_imag), cosm1_imag);
1098-
1099-
Value imag_is_zero =
1100-
b.create<ma::CmpFOp>(ma::CmpFPredicate::OEQ, imag, zero);
1101-
Value imag_result = b.create<ma::SelectOp>(
1102-
imag_is_zero, zero, b.create<ma::MulFOp>(exp_real, sin_imag));
1103-
1104-
rewriter.replaceOpWithNewOp<mc::CreateOp>(op, type, real_result,
1105-
imag_result);
1106-
return mlir::success();
1107-
}
1108-
1109-
private:
1110-
Value EmitCosm1(Value arg, ImplicitLocOpBuilder& b) const {
1111-
auto arg_type = mlir::cast<mlir::FloatType>(arg.getType());
1112-
auto negative_half =
1113-
b.create<ma::ConstantOp>(b.getFloatAttr(arg_type, -0.5));
1114-
auto negative_one =
1115-
b.create<ma::ConstantOp>(b.getFloatAttr(arg_type, -1.0));
1116-
1117-
// Algorithm copied from cephes cosm1:
1118-
// cosm1(x) = -0.5 * x^2 + x^4 * P(x^2);
1119-
// that is suitable when abs(x) < pi/4, otherwise we'll use cos(x)-1.
1120-
//
1121-
// This is an alternative algorithm
1122-
// cosm1(x) = -2 * sin(x/2)^2
1123-
// that is only slightly less accurate around abs(x) == 0.1 but
1124-
// otherwise equivalent accuracy-wise compared to cephes cosm1.
1125-
// However, we are not using it because it is notably less
1126-
// performant than cephes cosm1.
1127-
1128-
// TODO: define cosm1(x) as cosm1(x mod (2*pi)) to increase accuracy
1129-
// for large x values that are close to 2*pi*n where n is some integer.
1130-
static const std::array<double, 7> kCoeffs{
1131-
4.7377507964246204691685E-14, -1.1470284843425359765671E-11,
1132-
2.0876754287081521758361E-9, -2.7557319214999787979814E-7,
1133-
2.4801587301570552304991E-5, -1.3888888888888872993737E-3,
1134-
4.1666666666666666609054E-2,
1135-
};
1136-
Value cos = b.create<mm::CosOp>(arg);
1137-
Value for_large_x = b.create<ma::AddFOp>(cos, negative_one);
1138-
1139-
Value arg_pow_2 = b.create<ma::MulFOp>(arg, arg);
1140-
Value arg_pow_4 = b.create<ma::MulFOp>(arg_pow_2, arg_pow_2);
1141-
Value poly = EvaluatePolynomial(b, arg_pow_2, ArrayRef<double>(kCoeffs));
1142-
1143-
auto for_small_x =
1144-
b.create<ma::AddFOp>(b.create<ma::MulFOp>(arg_pow_4, poly),
1145-
b.create<ma::MulFOp>(negative_half, arg_pow_2));
1146-
1147-
// (pi/4)^2 is approximately 0.61685
1148-
Value cond = b.create<ma::CmpFOp>(
1149-
ma::CmpFPredicate::OGE, arg_pow_2,
1150-
b.create<ma::ConstantOp>(b.getFloatAttr(arg_type, 0.61685)));
1151-
return b.create<ma::SelectOp>(cond, for_large_x, for_small_x);
1152-
}
1153-
};
1154-
11551045
class LowerTensorsPass : public impl::LowerTensorsPassBase<LowerTensorsPass> {
11561046
public:
11571047
explicit LowerTensorsPass(const LowerTensorsPassOptions& options)
@@ -1162,7 +1052,7 @@ class LowerTensorsPass : public impl::LowerTensorsPassBase<LowerTensorsPass> {
11621052
mlir::RewritePatternSet tensor_patterns(mlir_context);
11631053
tensor_patterns.add<RewriteAtomicRMW>(mlir_context, is_amd_gpu_, gpu_arch_);
11641054
tensor_patterns
1165-
.add<RewriteAllocateShared, RewriterExpm1Op, RewriteNonScalarConstants,
1055+
.add<RewriteAllocateShared, RewriteNonScalarConstants,
11661056
RewriteSyncThreads, RewriteTensorExtract, RewriteTransferRead,
11671057
RewriteTensorInsert, RewriteTransferWrite>(mlir_context);
11681058
if (mlir::failed(mlir::applyPatternsAndFoldGreedily(

xla/service/gpu/fusions/transforms/passes.td

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,10 +74,8 @@ def LowerTensorsPass : Pass<"xla-gpu-lower-tensors", "mlir::ModuleOp"> {
7474

7575
let dependentDialects = [
7676
"mlir::LLVM::LLVMDialect",
77-
"mlir::complex::ComplexDialect",
7877
"mlir::func::FuncDialect",
7978
"mlir::gpu::GPUDialect",
80-
"mlir::math::MathDialect",
8179
"mlir::scf::SCFDialect",
8280
"mlir::tensor::TensorDialect",
8381
"xla::gpu::XlaGpuDialect",

xla/service/gpu/fusions/transforms/tests/lower_tensors.mlir

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -732,15 +732,3 @@ func.func @int4_constant(%arg0: tensor<3xi4>, %arg1: index) -> i4 {
732732
// CHECK: llvm.mlir.global private constant
733733
// CHECK-SAME: dense<[18, 48]>
734734
// CHECK-LABEL: @int4_constant
735-
736-
// -----
737-
738-
func.func @complex_expm1_approx(%arg0: tensor<3xcomplex<f32>>, %i: index)
739-
-> complex<f32> {
740-
%extracted = tensor.extract %arg0[%i] : tensor<3xcomplex<f32>>
741-
%expm1 = complex.expm1 %extracted : complex<f32>
742-
return %expm1 : complex<f32>
743-
}
744-
// CHECK-LABEL: @complex_expm1_approx
745-
// CHECK: math.expm1
746-
// CHECK-COUNT-6: math.fma

0 commit comments

Comments (0)