@@ -3525,114 +3525,123 @@ struct SelectCaseOpConversion : public fir::FIROpConversion<fir::SelectCaseOp> {
3525
3525
}
3526
3526
};
3527
3527
3528
- // / Helper function for converting select ops. This function converts the
3529
- // / signature of the given block. If the new block signature is different from
3530
- // / `expectedTypes`, returns "failure".
3531
- static llvm::FailureOr<mlir::Block *>
3532
- getConvertedBlock (mlir::ConversionPatternRewriter &rewriter,
3533
- const mlir::TypeConverter *converter,
3534
- mlir::Operation *branchOp, mlir::Block *block,
3535
- mlir::TypeRange expectedTypes) {
3536
- assert (converter && " expected non-null type converter" );
3537
- assert (!block->isEntryBlock () && " entry blocks have no predecessors" );
3538
-
3539
- // There is nothing to do if the types already match.
3540
- if (block->getArgumentTypes () == expectedTypes)
3541
- return block;
3542
-
3543
- // Compute the new block argument types and convert the block.
3544
- std::optional<mlir::TypeConverter::SignatureConversion> conversion =
3545
- converter->convertBlockSignature (block);
3546
- if (!conversion)
3547
- return rewriter.notifyMatchFailure (branchOp,
3548
- " could not compute block signature" );
3549
- if (expectedTypes != conversion->getConvertedTypes ())
3550
- return rewriter.notifyMatchFailure (
3551
- branchOp,
3552
- " mismatch between adaptor operand types and computed block signature" );
3553
- return rewriter.applySignatureConversion (block, *conversion, converter);
3554
- }
3555
-
3528
+ // / Base class for SelectOpConversion and SelectRankOpConversion.
3556
3529
template <typename OP>
3557
- static llvm::LogicalResult
3558
- selectMatchAndRewrite (const fir::LLVMTypeConverter &lowering, OP select,
3559
- typename OP::Adaptor adaptor,
3560
- mlir::ConversionPatternRewriter &rewriter,
3561
- const mlir::TypeConverter *converter) {
3562
- unsigned conds = select.getNumConditions ();
3563
- auto cases = select.getCases ().getValue ();
3564
- mlir::Value selector = adaptor.getSelector ();
3565
- auto loc = select.getLoc ();
3566
- assert (conds > 0 && " select must have cases" );
3567
-
3568
- llvm::SmallVector<mlir::Block *> destinations;
3569
- llvm::SmallVector<mlir::ValueRange> destinationsOperands;
3570
- mlir::Block *defaultDestination;
3571
- mlir::ValueRange defaultOperands;
3572
- llvm::SmallVector<int32_t > caseValues;
3573
-
3574
- for (unsigned t = 0 ; t != conds; ++t) {
3575
- mlir::Block *dest = select.getSuccessor (t);
3576
- auto destOps = select.getSuccessorOperands (adaptor.getOperands (), t);
3577
- const mlir::Attribute &attr = cases[t];
3578
- if (auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>(attr)) {
3579
- destinationsOperands.push_back (destOps ? *destOps : mlir::ValueRange{});
3580
- auto convertedBlock =
3581
- getConvertedBlock (rewriter, converter, select, dest,
3582
- mlir::TypeRange (destinationsOperands.back ()));
3530
+ struct SelectOpConversionBase : public fir ::FIROpConversion<OP> {
3531
+ using fir::FIROpConversion<OP>::FIROpConversion;
3532
+
3533
+ private:
3534
+ // / Helper function for converting select ops. This function converts the
3535
+ // / signature of the given block. If the new block signature is different from
3536
+ // / `expectedTypes`, returns "failure".
3537
+ llvm::FailureOr<mlir::Block *>
3538
+ getConvertedBlock (mlir::ConversionPatternRewriter &rewriter,
3539
+ mlir::Operation *branchOp, mlir::Block *block,
3540
+ mlir::TypeRange expectedTypes) const {
3541
+ const mlir::TypeConverter *converter = this ->getTypeConverter ();
3542
+ assert (converter && " expected non-null type converter" );
3543
+ assert (!block->isEntryBlock () && " entry blocks have no predecessors" );
3544
+
3545
+ // There is nothing to do if the types already match.
3546
+ if (block->getArgumentTypes () == expectedTypes)
3547
+ return block;
3548
+
3549
+ // Compute the new block argument types and convert the block.
3550
+ std::optional<mlir::TypeConverter::SignatureConversion> conversion =
3551
+ converter->convertBlockSignature (block);
3552
+ if (!conversion)
3553
+ return rewriter.notifyMatchFailure (branchOp,
3554
+ " could not compute block signature" );
3555
+ if (expectedTypes != conversion->getConvertedTypes ())
3556
+ return rewriter.notifyMatchFailure (branchOp,
3557
+ " mismatch between adaptor operand "
3558
+ " types and computed block signature" );
3559
+ return rewriter.applySignatureConversion (block, *conversion, converter);
3560
+ }
3561
+
3562
+ protected:
3563
+ llvm::LogicalResult
3564
+ selectMatchAndRewrite (OP select, typename OP::Adaptor adaptor,
3565
+ mlir::ConversionPatternRewriter &rewriter) const {
3566
+ unsigned conds = select.getNumConditions ();
3567
+ auto cases = select.getCases ().getValue ();
3568
+ mlir::Value selector = adaptor.getSelector ();
3569
+ auto loc = select.getLoc ();
3570
+ assert (conds > 0 && " select must have cases" );
3571
+
3572
+ llvm::SmallVector<mlir::Block *> destinations;
3573
+ llvm::SmallVector<mlir::ValueRange> destinationsOperands;
3574
+ mlir::Block *defaultDestination;
3575
+ mlir::ValueRange defaultOperands;
3576
+ // LLVM::SwitchOp selector type and the case values types
3577
+ // must have the same bit width, so cast the selector to i64,
3578
+ // and use i64 for the case values. It is hard to imagine
3579
+ // a computed GO TO with the number of labels in the label-list
3580
+ // bigger than INT_MAX, but let's use i64 to be on the safe side.
3581
+ // Moreover, fir.select operation is more relaxed than
3582
+ // a Fortran computed GO TO, so it may specify such a case value
3583
+ // even if there is just a single label/case.
3584
+ llvm::SmallVector<int64_t > caseValues;
3585
+
3586
+ for (unsigned t = 0 ; t != conds; ++t) {
3587
+ mlir::Block *dest = select.getSuccessor (t);
3588
+ auto destOps = select.getSuccessorOperands (adaptor.getOperands (), t);
3589
+ const mlir::Attribute &attr = cases[t];
3590
+ if (auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>(attr)) {
3591
+ destinationsOperands.push_back (destOps ? *destOps : mlir::ValueRange{});
3592
+ auto convertedBlock =
3593
+ getConvertedBlock (rewriter, select, dest,
3594
+ mlir::TypeRange (destinationsOperands.back ()));
3595
+ if (mlir::failed (convertedBlock))
3596
+ return mlir::failure ();
3597
+ destinations.push_back (*convertedBlock);
3598
+ caseValues.push_back (intAttr.getInt ());
3599
+ continue ;
3600
+ }
3601
+ assert (mlir::dyn_cast_or_null<mlir::UnitAttr>(attr));
3602
+ assert ((t + 1 == conds) && " unit must be last" );
3603
+ defaultOperands = destOps ? *destOps : mlir::ValueRange{};
3604
+ auto convertedBlock = getConvertedBlock (rewriter, select, dest,
3605
+ mlir::TypeRange (defaultOperands));
3583
3606
if (mlir::failed (convertedBlock))
3584
3607
return mlir::failure ();
3585
- destinations.push_back (*convertedBlock);
3586
- caseValues.push_back (intAttr.getInt ());
3587
- continue ;
3608
+ defaultDestination = *convertedBlock;
3588
3609
}
3589
- assert (mlir::dyn_cast_or_null<mlir::UnitAttr>(attr));
3590
- assert ((t + 1 == conds) && " unit must be last" );
3591
- defaultOperands = destOps ? *destOps : mlir::ValueRange{};
3592
- auto convertedBlock = getConvertedBlock (rewriter, converter, select, dest,
3593
- mlir::TypeRange (defaultOperands));
3594
- if (mlir::failed (convertedBlock))
3595
- return mlir::failure ();
3596
- defaultDestination = *convertedBlock;
3597
- }
3598
-
3599
- // LLVM::SwitchOp takes a i32 type for the selector.
3600
- if (select.getSelector ().getType () != rewriter.getI32Type ())
3601
- selector = mlir::LLVM::TruncOp::create (rewriter, loc, rewriter.getI32Type (),
3602
- selector);
3603
-
3604
- rewriter.replaceOpWithNewOp <mlir::LLVM::SwitchOp>(
3605
- select, selector,
3606
- /* defaultDestination=*/ defaultDestination,
3607
- /* defaultOperands=*/ defaultOperands,
3608
- /* caseValues=*/ caseValues,
3609
- /* caseDestinations=*/ destinations,
3610
- /* caseOperands=*/ destinationsOperands,
3611
- /* branchWeights=*/ llvm::ArrayRef<std::int32_t >());
3612
- return mlir::success ();
3613
- }
3614
3610
3611
+ selector =
3612
+ this ->integerCast (loc, rewriter, rewriter.getI64Type (), selector);
3613
+
3614
+ rewriter.replaceOpWithNewOp <mlir::LLVM::SwitchOp>(
3615
+ select, selector,
3616
+ /* defaultDestination=*/ defaultDestination,
3617
+ /* defaultOperands=*/ defaultOperands,
3618
+ /* caseValues=*/ rewriter.getI64VectorAttr (caseValues),
3619
+ /* caseDestinations=*/ destinations,
3620
+ /* caseOperands=*/ destinationsOperands,
3621
+ /* branchWeights=*/ llvm::ArrayRef<std::int32_t >());
3622
+ return mlir::success ();
3623
+ }
3624
+ };
3615
3625
// / conversion of fir::SelectOp to an if-then-else ladder
3616
- struct SelectOpConversion : public fir ::FIROpConversion <fir::SelectOp> {
3617
- using FIROpConversion::FIROpConversion ;
3626
+ struct SelectOpConversion : public SelectOpConversionBase <fir::SelectOp> {
3627
+ using SelectOpConversionBase::SelectOpConversionBase ;
3618
3628
3619
3629
llvm::LogicalResult
3620
3630
matchAndRewrite (fir::SelectOp op, OpAdaptor adaptor,
3621
3631
mlir::ConversionPatternRewriter &rewriter) const override {
3622
- return selectMatchAndRewrite<fir::SelectOp>(lowerTy (), op, adaptor,
3623
- rewriter, getTypeConverter ());
3632
+ return this ->selectMatchAndRewrite (op, adaptor, rewriter);
3624
3633
}
3625
3634
};
3626
3635
3627
3636
// / conversion of fir::SelectRankOp to an if-then-else ladder
3628
- struct SelectRankOpConversion : public fir ::FIROpConversion<fir::SelectRankOp> {
3629
- using FIROpConversion::FIROpConversion;
3637
+ struct SelectRankOpConversion
3638
+ : public SelectOpConversionBase<fir::SelectRankOp> {
3639
+ using SelectOpConversionBase::SelectOpConversionBase;
3630
3640
3631
3641
llvm::LogicalResult
3632
3642
matchAndRewrite (fir::SelectRankOp op, OpAdaptor adaptor,
3633
3643
mlir::ConversionPatternRewriter &rewriter) const override {
3634
- return selectMatchAndRewrite<fir::SelectRankOp>(
3635
- lowerTy (), op, adaptor, rewriter, getTypeConverter ());
3644
+ return this ->selectMatchAndRewrite (op, adaptor, rewriter);
3636
3645
}
3637
3646
};
3638
3647
0 commit comments