@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212See the License for the specific language governing permissions and
1313limitations under the License.
1414==============================================================================*/
15- #include < array>
1615#include < cassert>
1716#include < cstdint>
1817#include < memory>
@@ -31,13 +30,11 @@ limitations under the License.
3130#include " llvm/Support/LogicalResult.h"
3231#include " mlir/Conversion/LLVMCommon/TypeConverter.h"
3332#include " mlir/Dialect/Arith/IR/Arith.h"
34- #include " mlir/Dialect/Complex/IR/Complex.h"
3533#include " mlir/Dialect/Func/IR/FuncOps.h"
3634#include " mlir/Dialect/GPU/IR/GPUDialect.h"
3735#include " mlir/Dialect/LLVMIR/LLVMAttrs.h"
3836#include " mlir/Dialect/LLVMIR/LLVMDialect.h"
3937#include " mlir/Dialect/LLVMIR/LLVMTypes.h"
40- #include " mlir/Dialect/Math/IR/Math.h"
4138#include " mlir/Dialect/SCF/IR/SCF.h"
4239#include " mlir/Dialect/Tensor/IR/Tensor.h"
4340#include " mlir/Dialect/Vector/IR/VectorOps.h"
@@ -74,26 +71,19 @@ namespace {
7471#define GEN_PASS_DEF_LOWERTENSORSPASS
7572#include " xla/service/gpu/fusions/transforms/passes.h.inc"
7673
77- using llvm::APFloat;
78- using llvm::ArrayRef;
7974using mlir::failure;
80- using mlir::ImplicitLocOpBuilder;
8175using mlir::Location;
8276using mlir::LogicalResult;
8377using mlir::MLIRContext;
8478using mlir::OpBuilder;
8579using mlir::Operation;
86- using mlir::OpRewritePattern;
8780using mlir::success;
8881using mlir::Type;
8982using mlir::TypedValue;
9083using mlir::TypeRange;
9184using mlir::Value;
9285using mlir::ValueRange;
9386
94- namespace ma = ::mlir::arith;
95- namespace mc = ::mlir::complex ;
96- namespace mm = ::mlir::math;
9787namespace arith = ::mlir::arith;
9888namespace scf = ::mlir::scf;
9989namespace ml = ::mlir::LLVM;
@@ -249,7 +239,7 @@ struct RewriteTensorExtract : mlir::OpRewritePattern<mlir::tensor::ExtractOp> {
249239 load, b.create <mlir::arith::ConstantIntOp>(4 , load.getType ()));
250240 load = b.create <mlir::arith::TruncIOp>(
251241 op.getType (),
252- b.create <ma ::SelectOp>(is_low_nibble, load, high_value));
242+ b.create <mlir::arith ::SelectOp>(is_low_nibble, load, high_value));
253243 }
254244
255245 rewriter.replaceOpWithNewOp <mlir::UnrealizedConversionCastOp>(
@@ -378,7 +368,7 @@ struct RewriteTensorInsert : mlir::OpRewritePattern<mlir::tensor::InsertOp> {
378368 body_builder.create <mlir::arith::ShLIOp>(
379369 scalar_value,
380370 body_builder.create <mlir::arith::ConstantIntOp>(4 , ty)));
381- Value new_value = body_builder.create <ma ::SelectOp>(
371+ Value new_value = body_builder.create <mlir::arith ::SelectOp>(
382372 is_low_nibble, low_updated, high_updated);
383373 body_builder.create <mlir::scf::YieldOp>(new_value);
384374 Value casted_result = b.create <mlir::UnrealizedConversionCastOp>(
@@ -1052,106 +1042,6 @@ class RewriteAtomicRMW : public mlir::OpRewritePattern<AtomicRMWOp> {
10521042 std::string gpu_arch_;
10531043};
10541044
1055- template <typename FType>
1056- Value EvaluatePolynomial (ImplicitLocOpBuilder& b, Value arg,
1057- ArrayRef<FType> coefficients) {
1058- auto arg_type = mlir::cast<mlir::FloatType>(arg.getType ());
1059- Value poly =
1060- b.create <ma::ConstantOp>(b.getFloatAttr (arg_type, coefficients[0 ]));
1061- for (int i = 1 ; i < coefficients.size (); ++i) {
1062- poly = b.create <mm::FmaOp>(
1063- poly, arg,
1064- b.create <ma::ConstantOp>(b.getFloatAttr (arg_type, coefficients[i])));
1065- }
1066- return poly;
1067- };
1068-
1069- struct RewriterExpm1Op : public OpRewritePattern <mc::Expm1Op> {
1070- using OpRewritePattern<mc::Expm1Op>::OpRewritePattern;
1071-
1072- // e^(a+bi)-1 = (e^a*cos(b)-1)+e^a*sin(b)i
1073- // [handle inaccuracies when a and/or b are small]
1074- // = ((e^a - 1) * cos(b) + cos(b) - 1) + e^a*sin(b)i
1075- // = (expm1(a) * cos(b) + cosm1(b)) + e^a*sin(b)i
1076- mlir::LogicalResult matchAndRewrite (
1077- mc::Expm1Op op, mlir::PatternRewriter& rewriter) const override {
1078- auto type = op.getType ();
1079- auto element_type = mlir::cast<mlir::FloatType>(type.getElementType ());
1080-
1081- ImplicitLocOpBuilder b (op.getLoc (), rewriter);
1082-
1083- Value real = b.create <mc::ReOp>(op.getComplex ());
1084- Value imag = b.create <mc::ImOp>(op.getComplex ());
1085-
1086- Value zero = b.create <ma::ConstantOp>(b.getFloatAttr (element_type, 0.0 ));
1087- Value one = b.create <ma::ConstantOp>(b.getFloatAttr (element_type, 1.0 ));
1088-
1089- Value expm1_real = b.create <mm::ExpM1Op>(real);
1090- Value exp_real = b.create <ma::AddFOp>(expm1_real, one);
1091-
1092- Value sin_imag = b.create <mm::SinOp>(imag);
1093- Value cosm1_imag = EmitCosm1 (imag, b);
1094- Value cos_imag = b.create <ma::AddFOp>(cosm1_imag, one);
1095-
1096- Value real_result = b.create <ma::AddFOp>(
1097- b.create <ma::MulFOp>(expm1_real, cos_imag), cosm1_imag);
1098-
1099- Value imag_is_zero =
1100- b.create <ma::CmpFOp>(ma::CmpFPredicate::OEQ, imag, zero);
1101- Value imag_result = b.create <ma::SelectOp>(
1102- imag_is_zero, zero, b.create <ma::MulFOp>(exp_real, sin_imag));
1103-
1104- rewriter.replaceOpWithNewOp <mc::CreateOp>(op, type, real_result,
1105- imag_result);
1106- return mlir::success ();
1107- }
1108-
1109- private:
1110- Value EmitCosm1 (Value arg, ImplicitLocOpBuilder& b) const {
1111- auto arg_type = mlir::cast<mlir::FloatType>(arg.getType ());
1112- auto negative_half =
1113- b.create <ma::ConstantOp>(b.getFloatAttr (arg_type, -0.5 ));
1114- auto negative_one =
1115- b.create <ma::ConstantOp>(b.getFloatAttr (arg_type, -1.0 ));
1116-
1117- // Algorithm copied from cephes cosm1:
1118- // cosm1(x) = -0.5 * x^2 + x^4 * P(x^2);
1119- // that is suitable when abs(x) < pi/4, otherwise we'll use cos(x)-1.
1120- //
1121- // This is an alternative algorithm
1122- // cosm1(x) = -2 * sin(x/2)^2
1123- // that is only slightly less accurate around abs(x) == 0.1 but
1124- // otherwise equivalent accuracy-wise compared to cephes cosm1.
1125- // However, we are not using it because it is notably less
1126- // performant than cephes cosm1.
1127-
1128- // TODO: define cosm1(x) as cosm1(x mod (2*pi)) to increase accuracy
1129- // for large x values that are close to 2*pi*n where n is some integer.
1130- static const std::array<double , 7 > kCoeffs {
1131- 4.7377507964246204691685E-14 , -1.1470284843425359765671E-11 ,
1132- 2.0876754287081521758361E-9 , -2.7557319214999787979814E-7 ,
1133- 2.4801587301570552304991E-5 , -1.3888888888888872993737E-3 ,
1134- 4.1666666666666666609054E-2 ,
1135- };
1136- Value cos = b.create <mm::CosOp>(arg);
1137- Value for_large_x = b.create <ma::AddFOp>(cos, negative_one);
1138-
1139- Value arg_pow_2 = b.create <ma::MulFOp>(arg, arg);
1140- Value arg_pow_4 = b.create <ma::MulFOp>(arg_pow_2, arg_pow_2);
1141- Value poly = EvaluatePolynomial (b, arg_pow_2, ArrayRef<double >(kCoeffs ));
1142-
1143- auto for_small_x =
1144- b.create <ma::AddFOp>(b.create <ma::MulFOp>(arg_pow_4, poly),
1145- b.create <ma::MulFOp>(negative_half, arg_pow_2));
1146-
1147- // (pi/4)^2 is approximately 0.61685
1148- Value cond = b.create <ma::CmpFOp>(
1149- ma::CmpFPredicate::OGE, arg_pow_2,
1150- b.create <ma::ConstantOp>(b.getFloatAttr (arg_type, 0.61685 )));
1151- return b.create <ma::SelectOp>(cond, for_large_x, for_small_x);
1152- }
1153- };
1154-
11551045class LowerTensorsPass : public impl ::LowerTensorsPassBase<LowerTensorsPass> {
11561046 public:
11571047 explicit LowerTensorsPass (const LowerTensorsPassOptions& options)
@@ -1162,7 +1052,7 @@ class LowerTensorsPass : public impl::LowerTensorsPassBase<LowerTensorsPass> {
11621052 mlir::RewritePatternSet tensor_patterns (mlir_context);
11631053 tensor_patterns.add <RewriteAtomicRMW>(mlir_context, is_amd_gpu_, gpu_arch_);
11641054 tensor_patterns
1165- .add <RewriteAllocateShared, RewriterExpm1Op, RewriteNonScalarConstants,
1055+ .add <RewriteAllocateShared, RewriteNonScalarConstants,
11661056 RewriteSyncThreads, RewriteTensorExtract, RewriteTransferRead,
11671057 RewriteTensorInsert, RewriteTransferWrite>(mlir_context);
11681058 if (mlir::failed (mlir::applyPatternsAndFoldGreedily (
0 commit comments