1818#include " mlir/IR/TypeUtilities.h"
1919#include " mlir/Pass/Pass.h"
2020
21+ #include " llvm/ADT/FloatingPointMode.h"
22+
2123namespace mlir {
2224#define GEN_PASS_DEF_CONVERTMATHTOLLVMPASS
2325#include " mlir/Conversion/Passes.h.inc"
@@ -286,6 +288,40 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
286288 }
287289};
288290
291+ struct IsNaNOpLowering : public ConvertOpToLLVMPattern <math::IsNaNOp> {
292+ using ConvertOpToLLVMPattern<math::IsNaNOp>::ConvertOpToLLVMPattern;
293+
294+ LogicalResult
295+ matchAndRewrite (math::IsNaNOp op, OpAdaptor adaptor,
296+ ConversionPatternRewriter &rewriter) const override {
297+ auto operandType = adaptor.getOperand ().getType ();
298+
299+ if (!operandType || !LLVM::isCompatibleType (operandType))
300+ return failure ();
301+
302+ rewriter.replaceOpWithNewOp <LLVM::IsFPClass>(
303+ op, op.getType (), adaptor.getOperand (), llvm::fcNan);
304+ return success ();
305+ }
306+ };
307+
308+ struct IsFiniteOpLowering : public ConvertOpToLLVMPattern <math::IsFiniteOp> {
309+ using ConvertOpToLLVMPattern<math::IsFiniteOp>::ConvertOpToLLVMPattern;
310+
311+ LogicalResult
312+ matchAndRewrite (math::IsFiniteOp op, OpAdaptor adaptor,
313+ ConversionPatternRewriter &rewriter) const override {
314+ auto operandType = adaptor.getOperand ().getType ();
315+
316+ if (!operandType || !LLVM::isCompatibleType (operandType))
317+ return failure ();
318+
319+ rewriter.replaceOpWithNewOp <LLVM::IsFPClass>(
320+ op, op.getType (), adaptor.getOperand (), llvm::fcFinite);
321+ return success ();
322+ }
323+ };
324+
289325struct ConvertMathToLLVMPass
290326 : public impl::ConvertMathToLLVMPassBase<ConvertMathToLLVMPass> {
291327 using Base::Base;
@@ -309,6 +345,8 @@ void mlir::populateMathToLLVMConversionPatterns(
309345 patterns.add <Log1pOpLowering>(converter, benefit);
310346 // clang-format off
311347 patterns.add <
348+ IsNaNOpLowering,
349+ IsFiniteOpLowering,
312350 AbsFOpLowering,
313351 AbsIOpLowering,
314352 CeilOpLowering,
0 commit comments