@@ -238,6 +238,16 @@ struct CmpFOpLowering : public ConvertOpToLLVMPattern<arith::CmpFOp> {
238
238
ConversionPatternRewriter &rewriter) const override ;
239
239
};
240
240
241
+ struct SelectOpOneToNLowering : public ConvertOpToLLVMPattern <arith::SelectOp> {
242
+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
243
+ using Adaptor =
244
+ typename ConvertOpToLLVMPattern<arith::SelectOp>::OneToNOpAdaptor;
245
+
246
+ LogicalResult
247
+ matchAndRewrite (arith::SelectOp op, Adaptor adaptor,
248
+ ConversionPatternRewriter &rewriter) const override ;
249
+ };
250
+
241
251
} // namespace
242
252
243
253
// ===----------------------------------------------------------------------===//
@@ -479,6 +489,32 @@ CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
479
489
rewriter);
480
490
}
481
491
492
+ // ===----------------------------------------------------------------------===//
493
+ // SelectOpOneToNLowering
494
+ // ===----------------------------------------------------------------------===//
495
+
496
+ // / Pattern for arith.select where the true/false values lower to multiple
497
+ // / SSA values (1:N conversion). This pattern generates multiple arith.select
498
+ // / than can be lowered by the 1:1 arith.select pattern.
499
+ LogicalResult SelectOpOneToNLowering::matchAndRewrite (
500
+ arith::SelectOp op, Adaptor adaptor,
501
+ ConversionPatternRewriter &rewriter) const {
502
+ // In case of a 1:1 conversion, the 1:1 pattern will match.
503
+ if (llvm::hasSingleElement (adaptor.getTrueValue ()))
504
+ return rewriter.notifyMatchFailure (
505
+ op, " not a 1:N conversion, 1:1 pattern will match" );
506
+ if (!op.getCondition ().getType ().isInteger (1 ))
507
+ return rewriter.notifyMatchFailure (op,
508
+ " non-i1 conditions are not supported" );
509
+ SmallVector<Value> results;
510
+ for (auto [trueValue, falseValue] :
511
+ llvm::zip_equal (adaptor.getTrueValue (), adaptor.getFalseValue ()))
512
+ results.push_back (arith::SelectOp::create (
513
+ rewriter, op.getLoc (), op.getCondition (), trueValue, falseValue));
514
+ rewriter.replaceOpWithMultiple (op, {results});
515
+ return success ();
516
+ }
517
+
482
518
// ===----------------------------------------------------------------------===//
483
519
// Pass Definition
484
520
// ===----------------------------------------------------------------------===//
@@ -587,6 +623,7 @@ void mlir::arith::populateArithToLLVMConversionPatterns(
587
623
RemSIOpLowering,
588
624
RemUIOpLowering,
589
625
SelectOpLowering,
626
+ SelectOpOneToNLowering,
590
627
ShLIOpLowering,
591
628
ShRSIOpLowering,
592
629
ShRUIOpLowering,
0 commit comments