@@ -286,6 +286,40 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
286286 }
287287};
288288
289+ struct IsNaNOpLowering : public ConvertOpToLLVMPattern <math::IsNaNOp> {
290+ using ConvertOpToLLVMPattern<math::IsNaNOp>::ConvertOpToLLVMPattern;
291+
292+ LogicalResult
293+ matchAndRewrite (math::IsNaNOp op, OpAdaptor adaptor,
294+ ConversionPatternRewriter &rewriter) const override {
295+ auto operandType = adaptor.getOperand ().getType ();
296+
297+ if (!operandType || !LLVM::isCompatibleType (operandType))
298+ return failure ();
299+
300+ rewriter.replaceOpWithNewOp <LLVM::IsFPClass>(op, op.getType (),
301+ adaptor.getOperand (), 3 );
302+ return success ();
303+ }
304+ };
305+
306+ struct IsFiniteOpLowering : public ConvertOpToLLVMPattern <math::IsFiniteOp> {
307+ using ConvertOpToLLVMPattern<math::IsFiniteOp>::ConvertOpToLLVMPattern;
308+
309+ LogicalResult
310+ matchAndRewrite (math::IsFiniteOp op, OpAdaptor adaptor,
311+ ConversionPatternRewriter &rewriter) const override {
312+ auto operandType = adaptor.getOperand ().getType ();
313+
314+ if (!operandType || !LLVM::isCompatibleType (operandType))
315+ return failure ();
316+
317+ rewriter.replaceOpWithNewOp <LLVM::IsFPClass>(op, op.getType (),
318+ adaptor.getOperand (), 504 );
319+ return success ();
320+ }
321+ };
322+
289323struct ConvertMathToLLVMPass
290324 : public impl::ConvertMathToLLVMPassBase<ConvertMathToLLVMPass> {
291325 using Base::Base;
@@ -307,6 +341,8 @@ void mlir::populateMathToLLVMConversionPatterns(
307341 bool approximateLog1p, PatternBenefit benefit) {
308342 if (approximateLog1p)
309343 patterns.add <Log1pOpLowering>(converter, benefit);
344+ patterns.add <IsNaNOpLowering, IsFiniteOpLowering>(converter);
345+
310346 // clang-format off
311347 patterns.add <
312348 AbsFOpLowering,
0 commit comments