@@ -238,6 +238,16 @@ struct CmpFOpLowering : public ConvertOpToLLVMPattern<arith::CmpFOp> {
238238 ConversionPatternRewriter &rewriter) const override ;
239239};
240240
241+ struct SelectOpOneToNLowering : public ConvertOpToLLVMPattern <arith::SelectOp> {
242+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
243+ using Adaptor =
244+ typename ConvertOpToLLVMPattern<arith::SelectOp>::OneToNOpAdaptor;
245+
246+ LogicalResult
247+ matchAndRewrite (arith::SelectOp op, Adaptor adaptor,
248+ ConversionPatternRewriter &rewriter) const override ;
249+ };
250+
241251} // namespace
242252
243253// ===----------------------------------------------------------------------===//
@@ -479,6 +489,32 @@ CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
479489 rewriter);
480490}
481491
492+ // ===----------------------------------------------------------------------===//
493+ // SelectOpOneToNLowering
494+ // ===----------------------------------------------------------------------===//
495+
496+ // / Pattern for arith.select where the true/false values lower to multiple
497+ // / SSA values (1:N conversion). This pattern generates multiple arith.select
498+ // / than can be lowered by the 1:1 arith.select pattern.
499+ LogicalResult SelectOpOneToNLowering::matchAndRewrite (
500+ arith::SelectOp op, Adaptor adaptor,
501+ ConversionPatternRewriter &rewriter) const {
502+ // In case of a 1:1 conversion, the 1:1 pattern will match.
503+ if (llvm::hasSingleElement (adaptor.getTrueValue ()))
504+ return rewriter.notifyMatchFailure (
505+ op, " not a 1:N conversion, 1:1 pattern will match" );
506+ if (!op.getCondition ().getType ().isInteger (1 ))
507+ return rewriter.notifyMatchFailure (op,
508+ " non-i1 conditions are not supported" );
509+ SmallVector<Value> results;
510+ for (auto [trueValue, falseValue] :
511+ llvm::zip_equal (adaptor.getTrueValue (), adaptor.getFalseValue ()))
512+ results.push_back (arith::SelectOp::create (
513+ rewriter, op.getLoc (), op.getCondition (), trueValue, falseValue));
514+ rewriter.replaceOpWithMultiple (op, {results});
515+ return success ();
516+ }
517+
482518// ===----------------------------------------------------------------------===//
483519// Pass Definition
484520// ===----------------------------------------------------------------------===//
@@ -587,6 +623,7 @@ void mlir::arith::populateArithToLLVMConversionPatterns(
587623 RemSIOpLowering,
588624 RemUIOpLowering,
589625 SelectOpLowering,
626+ SelectOpOneToNLowering,
590627 ShLIOpLowering,
591628 ShRSIOpLowering,
592629 ShRUIOpLowering,
0 commit comments