@@ -3525,114 +3525,123 @@ struct SelectCaseOpConversion : public fir::FIROpConversion<fir::SelectCaseOp> {
35253525  }
35263526};
35273527
3528- // / Helper function for converting select ops. This function converts the
3529- // / signature of the given block. If the new block signature is different from
3530- // / `expectedTypes`, returns "failure".
3531- static  llvm::FailureOr<mlir::Block *>
3532- getConvertedBlock (mlir::ConversionPatternRewriter &rewriter,
3533-                   const  mlir::TypeConverter *converter,
3534-                   mlir::Operation *branchOp, mlir::Block *block,
3535-                   mlir::TypeRange expectedTypes) {
3536-   assert (converter && " expected non-null type converter"  );
3537-   assert (!block->isEntryBlock () && " entry blocks have no predecessors"  );
3538- 
3539-   //  There is nothing to do if the types already match.
3540-   if  (block->getArgumentTypes () == expectedTypes)
3541-     return  block;
3542- 
3543-   //  Compute the new block argument types and convert the block.
3544-   std::optional<mlir::TypeConverter::SignatureConversion> conversion =
3545-       converter->convertBlockSignature (block);
3546-   if  (!conversion)
3547-     return  rewriter.notifyMatchFailure (branchOp,
3548-                                        " could not compute block signature"  );
3549-   if  (expectedTypes != conversion->getConvertedTypes ())
3550-     return  rewriter.notifyMatchFailure (
3551-         branchOp,
3552-         " mismatch between adaptor operand types and computed block signature"  );
3553-   return  rewriter.applySignatureConversion (block, *conversion, converter);
3554- }
3555- 
3528+ // / Base class for SelectOpConversion and SelectRankOpConversion.
35563529template  <typename  OP>
3557- static  llvm::LogicalResult
3558- selectMatchAndRewrite (const  fir::LLVMTypeConverter &lowering, OP select,
3559-                       typename  OP::Adaptor adaptor,
3560-                       mlir::ConversionPatternRewriter &rewriter,
3561-                       const  mlir::TypeConverter *converter) {
3562-   unsigned  conds = select.getNumConditions ();
3563-   auto  cases = select.getCases ().getValue ();
3564-   mlir::Value selector = adaptor.getSelector ();
3565-   auto  loc = select.getLoc ();
3566-   assert (conds > 0  && " select must have cases"  );
3567- 
3568-   llvm::SmallVector<mlir::Block *> destinations;
3569-   llvm::SmallVector<mlir::ValueRange> destinationsOperands;
3570-   mlir::Block *defaultDestination;
3571-   mlir::ValueRange defaultOperands;
3572-   llvm::SmallVector<int32_t > caseValues;
3573- 
3574-   for  (unsigned  t = 0 ; t != conds; ++t) {
3575-     mlir::Block *dest = select.getSuccessor (t);
3576-     auto  destOps = select.getSuccessorOperands (adaptor.getOperands (), t);
3577-     const  mlir::Attribute &attr = cases[t];
3578-     if  (auto  intAttr = mlir::dyn_cast<mlir::IntegerAttr>(attr)) {
3579-       destinationsOperands.push_back (destOps ? *destOps : mlir::ValueRange{});
3580-       auto  convertedBlock =
3581-           getConvertedBlock (rewriter, converter, select, dest,
3582-                             mlir::TypeRange (destinationsOperands.back ()));
3530+ struct  SelectOpConversionBase  : public  fir ::FIROpConversion<OP> {
3531+   using  fir::FIROpConversion<OP>::FIROpConversion;
3532+ 
3533+ private: 
3534+   // / Helper function for converting select ops. This function converts the
3535+   // / signature of the given block. If the new block signature is different from
3536+   // / `expectedTypes`, returns "failure".
3537+   llvm::FailureOr<mlir::Block *>
3538+   getConvertedBlock (mlir::ConversionPatternRewriter &rewriter,
3539+                     mlir::Operation *branchOp, mlir::Block *block,
3540+                     mlir::TypeRange expectedTypes) const  {
3541+     const  mlir::TypeConverter *converter = this ->getTypeConverter ();
3542+     assert (converter && " expected non-null type converter"  );
3543+     assert (!block->isEntryBlock () && " entry blocks have no predecessors"  );
3544+ 
3545+     //  There is nothing to do if the types already match.
3546+     if  (block->getArgumentTypes () == expectedTypes)
3547+       return  block;
3548+ 
3549+     //  Compute the new block argument types and convert the block.
3550+     std::optional<mlir::TypeConverter::SignatureConversion> conversion =
3551+         converter->convertBlockSignature (block);
3552+     if  (!conversion)
3553+       return  rewriter.notifyMatchFailure (branchOp,
3554+                                          " could not compute block signature"  );
3555+     if  (expectedTypes != conversion->getConvertedTypes ())
3556+       return  rewriter.notifyMatchFailure (branchOp,
3557+                                          " mismatch between adaptor operand " 
3558+                                          " types and computed block signature"  );
3559+     return  rewriter.applySignatureConversion (block, *conversion, converter);
3560+   }
3561+ 
3562+ protected: 
3563+   llvm::LogicalResult
3564+   selectMatchAndRewrite (OP select, typename  OP::Adaptor adaptor,
3565+                         mlir::ConversionPatternRewriter &rewriter) const  {
3566+     unsigned  conds = select.getNumConditions ();
3567+     auto  cases = select.getCases ().getValue ();
3568+     mlir::Value selector = adaptor.getSelector ();
3569+     auto  loc = select.getLoc ();
3570+     assert (conds > 0  && " select must have cases"  );
3571+ 
3572+     llvm::SmallVector<mlir::Block *> destinations;
3573+     llvm::SmallVector<mlir::ValueRange> destinationsOperands;
3574+     mlir::Block *defaultDestination;
3575+     mlir::ValueRange defaultOperands;
3576+     //  LLVM::SwitchOp selector type and the case values types
3577+     //  must have the same bit width, so cast the selector to i64,
3578+     //  and use i64 for the case values. It is hard to imagine
3579+     //  a computed GO TO with the number of labels in the label-list
3580+     //  bigger than INT_MAX, but let's use i64 to be on the safe side.
3581+     //  Moreover, fir.select operation is more relaxed than
3582+     //  a Fortran computed GO TO, so it may specify such a case value
3583+     //  even if there is just a single label/case.
3584+     llvm::SmallVector<int64_t > caseValues;
3585+ 
3586+     for  (unsigned  t = 0 ; t != conds; ++t) {
3587+       mlir::Block *dest = select.getSuccessor (t);
3588+       auto  destOps = select.getSuccessorOperands (adaptor.getOperands (), t);
3589+       const  mlir::Attribute &attr = cases[t];
3590+       if  (auto  intAttr = mlir::dyn_cast<mlir::IntegerAttr>(attr)) {
3591+         destinationsOperands.push_back (destOps ? *destOps : mlir::ValueRange{});
3592+         auto  convertedBlock =
3593+             getConvertedBlock (rewriter, select, dest,
3594+                               mlir::TypeRange (destinationsOperands.back ()));
3595+         if  (mlir::failed (convertedBlock))
3596+           return  mlir::failure ();
3597+         destinations.push_back (*convertedBlock);
3598+         caseValues.push_back (intAttr.getInt ());
3599+         continue ;
3600+       }
3601+       assert (mlir::dyn_cast_or_null<mlir::UnitAttr>(attr));
3602+       assert ((t + 1  == conds) && " unit must be last"  );
3603+       defaultOperands = destOps ? *destOps : mlir::ValueRange{};
3604+       auto  convertedBlock = getConvertedBlock (rewriter, select, dest,
3605+                                               mlir::TypeRange (defaultOperands));
35833606      if  (mlir::failed (convertedBlock))
35843607        return  mlir::failure ();
3585-       destinations.push_back (*convertedBlock);
3586-       caseValues.push_back (intAttr.getInt ());
3587-       continue ;
3608+       defaultDestination = *convertedBlock;
35883609    }
3589-     assert (mlir::dyn_cast_or_null<mlir::UnitAttr>(attr));
3590-     assert ((t + 1  == conds) && " unit must be last"  );
3591-     defaultOperands = destOps ? *destOps : mlir::ValueRange{};
3592-     auto  convertedBlock = getConvertedBlock (rewriter, converter, select, dest,
3593-                                             mlir::TypeRange (defaultOperands));
3594-     if  (mlir::failed (convertedBlock))
3595-       return  mlir::failure ();
3596-     defaultDestination = *convertedBlock;
3597-   }
3598- 
3599-   //  LLVM::SwitchOp takes a i32 type for the selector.
3600-   if  (select.getSelector ().getType () != rewriter.getI32Type ())
3601-     selector = mlir::LLVM::TruncOp::create (rewriter, loc, rewriter.getI32Type (),
3602-                                            selector);
3603- 
3604-   rewriter.replaceOpWithNewOp <mlir::LLVM::SwitchOp>(
3605-       select, selector,
3606-       /* defaultDestination=*/  defaultDestination,
3607-       /* defaultOperands=*/  defaultOperands,
3608-       /* caseValues=*/  caseValues,
3609-       /* caseDestinations=*/  destinations,
3610-       /* caseOperands=*/  destinationsOperands,
3611-       /* branchWeights=*/  llvm::ArrayRef<std::int32_t >());
3612-   return  mlir::success ();
3613- }
36143610
3611+     selector =
3612+         this ->integerCast (loc, rewriter, rewriter.getI64Type (), selector);
3613+ 
3614+     rewriter.replaceOpWithNewOp <mlir::LLVM::SwitchOp>(
3615+         select, selector,
3616+         /* defaultDestination=*/  defaultDestination,
3617+         /* defaultOperands=*/  defaultOperands,
3618+         /* caseValues=*/  rewriter.getI64VectorAttr (caseValues),
3619+         /* caseDestinations=*/  destinations,
3620+         /* caseOperands=*/  destinationsOperands,
3621+         /* branchWeights=*/  llvm::ArrayRef<std::int32_t >());
3622+     return  mlir::success ();
3623+   }
3624+ };
36153625// / conversion of fir::SelectOp to an if-then-else ladder
3616- struct  SelectOpConversion  : public  fir ::FIROpConversion <fir::SelectOp> {
3617-   using  FIROpConversion::FIROpConversion ;
3626+ struct  SelectOpConversion  : public  SelectOpConversionBase <fir::SelectOp> {
3627+   using  SelectOpConversionBase::SelectOpConversionBase ;
36183628
36193629  llvm::LogicalResult
36203630  matchAndRewrite (fir::SelectOp op, OpAdaptor adaptor,
36213631                  mlir::ConversionPatternRewriter &rewriter) const  override  {
3622-     return  selectMatchAndRewrite<fir::SelectOp>(lowerTy (), op, adaptor,
3623-                                                 rewriter, getTypeConverter ());
3632+     return  this ->selectMatchAndRewrite (op, adaptor, rewriter);
36243633  }
36253634};
36263635
36273636// / conversion of fir::SelectRankOp to an if-then-else ladder
3628- struct  SelectRankOpConversion  : public  fir ::FIROpConversion<fir::SelectRankOp> {
3629-   using  FIROpConversion::FIROpConversion;
3637+ struct  SelectRankOpConversion 
3638+     : public SelectOpConversionBase<fir::SelectRankOp> {
3639+   using  SelectOpConversionBase::SelectOpConversionBase;
36303640
36313641  llvm::LogicalResult
36323642  matchAndRewrite (fir::SelectRankOp op, OpAdaptor adaptor,
36333643                  mlir::ConversionPatternRewriter &rewriter) const  override  {
3634-     return  selectMatchAndRewrite<fir::SelectRankOp>(
3635-         lowerTy (), op, adaptor, rewriter, getTypeConverter ());
3644+     return  this ->selectMatchAndRewrite (op, adaptor, rewriter);
36363645  }
36373646};
36383647
0 commit comments