Skip to content

Commit b8e4232

Browse files
authored
[flang] Cast fir.select[_rank] selector to i64. (#153239)
Properly cast the selector to `i64` regardless of its integer type. We used to generate llvm.trunc always. We have to use `i64` as long as the case values may exceed INT_MAX. Fixes #153050.
1 parent 6032ff6 commit b8e4232

File tree

3 files changed

+158
-98
lines changed

3 files changed

+158
-98
lines changed

flang/lib/Optimizer/CodeGen/CodeGen.cpp

Lines changed: 99 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -3525,114 +3525,123 @@ struct SelectCaseOpConversion : public fir::FIROpConversion<fir::SelectCaseOp> {
35253525
}
35263526
};
35273527

3528-
/// Helper function for converting select ops. This function converts the
3529-
/// signature of the given block. If the new block signature is different from
3530-
/// `expectedTypes`, returns "failure".
3531-
static llvm::FailureOr<mlir::Block *>
3532-
getConvertedBlock(mlir::ConversionPatternRewriter &rewriter,
3533-
const mlir::TypeConverter *converter,
3534-
mlir::Operation *branchOp, mlir::Block *block,
3535-
mlir::TypeRange expectedTypes) {
3536-
assert(converter && "expected non-null type converter");
3537-
assert(!block->isEntryBlock() && "entry blocks have no predecessors");
3538-
3539-
// There is nothing to do if the types already match.
3540-
if (block->getArgumentTypes() == expectedTypes)
3541-
return block;
3542-
3543-
// Compute the new block argument types and convert the block.
3544-
std::optional<mlir::TypeConverter::SignatureConversion> conversion =
3545-
converter->convertBlockSignature(block);
3546-
if (!conversion)
3547-
return rewriter.notifyMatchFailure(branchOp,
3548-
"could not compute block signature");
3549-
if (expectedTypes != conversion->getConvertedTypes())
3550-
return rewriter.notifyMatchFailure(
3551-
branchOp,
3552-
"mismatch between adaptor operand types and computed block signature");
3553-
return rewriter.applySignatureConversion(block, *conversion, converter);
3554-
}
3555-
3528+
/// Base class for SelectOpConversion and SelectRankOpConversion.
35563529
template <typename OP>
3557-
static llvm::LogicalResult
3558-
selectMatchAndRewrite(const fir::LLVMTypeConverter &lowering, OP select,
3559-
typename OP::Adaptor adaptor,
3560-
mlir::ConversionPatternRewriter &rewriter,
3561-
const mlir::TypeConverter *converter) {
3562-
unsigned conds = select.getNumConditions();
3563-
auto cases = select.getCases().getValue();
3564-
mlir::Value selector = adaptor.getSelector();
3565-
auto loc = select.getLoc();
3566-
assert(conds > 0 && "select must have cases");
3567-
3568-
llvm::SmallVector<mlir::Block *> destinations;
3569-
llvm::SmallVector<mlir::ValueRange> destinationsOperands;
3570-
mlir::Block *defaultDestination;
3571-
mlir::ValueRange defaultOperands;
3572-
llvm::SmallVector<int32_t> caseValues;
3573-
3574-
for (unsigned t = 0; t != conds; ++t) {
3575-
mlir::Block *dest = select.getSuccessor(t);
3576-
auto destOps = select.getSuccessorOperands(adaptor.getOperands(), t);
3577-
const mlir::Attribute &attr = cases[t];
3578-
if (auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>(attr)) {
3579-
destinationsOperands.push_back(destOps ? *destOps : mlir::ValueRange{});
3580-
auto convertedBlock =
3581-
getConvertedBlock(rewriter, converter, select, dest,
3582-
mlir::TypeRange(destinationsOperands.back()));
3530+
struct SelectOpConversionBase : public fir::FIROpConversion<OP> {
3531+
using fir::FIROpConversion<OP>::FIROpConversion;
3532+
3533+
private:
3534+
/// Helper function for converting select ops. This function converts the
3535+
/// signature of the given block. If the new block signature is different from
3536+
/// `expectedTypes`, returns "failure".
3537+
llvm::FailureOr<mlir::Block *>
3538+
getConvertedBlock(mlir::ConversionPatternRewriter &rewriter,
3539+
mlir::Operation *branchOp, mlir::Block *block,
3540+
mlir::TypeRange expectedTypes) const {
3541+
const mlir::TypeConverter *converter = this->getTypeConverter();
3542+
assert(converter && "expected non-null type converter");
3543+
assert(!block->isEntryBlock() && "entry blocks have no predecessors");
3544+
3545+
// There is nothing to do if the types already match.
3546+
if (block->getArgumentTypes() == expectedTypes)
3547+
return block;
3548+
3549+
// Compute the new block argument types and convert the block.
3550+
std::optional<mlir::TypeConverter::SignatureConversion> conversion =
3551+
converter->convertBlockSignature(block);
3552+
if (!conversion)
3553+
return rewriter.notifyMatchFailure(branchOp,
3554+
"could not compute block signature");
3555+
if (expectedTypes != conversion->getConvertedTypes())
3556+
return rewriter.notifyMatchFailure(branchOp,
3557+
"mismatch between adaptor operand "
3558+
"types and computed block signature");
3559+
return rewriter.applySignatureConversion(block, *conversion, converter);
3560+
}
3561+
3562+
protected:
3563+
llvm::LogicalResult
3564+
selectMatchAndRewrite(OP select, typename OP::Adaptor adaptor,
3565+
mlir::ConversionPatternRewriter &rewriter) const {
3566+
unsigned conds = select.getNumConditions();
3567+
auto cases = select.getCases().getValue();
3568+
mlir::Value selector = adaptor.getSelector();
3569+
auto loc = select.getLoc();
3570+
assert(conds > 0 && "select must have cases");
3571+
3572+
llvm::SmallVector<mlir::Block *> destinations;
3573+
llvm::SmallVector<mlir::ValueRange> destinationsOperands;
3574+
mlir::Block *defaultDestination;
3575+
mlir::ValueRange defaultOperands;
3576+
// LLVM::SwitchOp selector type and the case values types
3577+
// must have the same bit width, so cast the selector to i64,
3578+
// and use i64 for the case values. It is hard to imagine
3579+
// a computed GO TO with the number of labels in the label-list
3580+
// bigger than INT_MAX, but let's use i64 to be on the safe side.
3581+
// Moreover, fir.select operation is more relaxed than
3582+
// a Fortran computed GO TO, so it may specify such a case value
3583+
// even if there is just a single label/case.
3584+
llvm::SmallVector<int64_t> caseValues;
3585+
3586+
for (unsigned t = 0; t != conds; ++t) {
3587+
mlir::Block *dest = select.getSuccessor(t);
3588+
auto destOps = select.getSuccessorOperands(adaptor.getOperands(), t);
3589+
const mlir::Attribute &attr = cases[t];
3590+
if (auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>(attr)) {
3591+
destinationsOperands.push_back(destOps ? *destOps : mlir::ValueRange{});
3592+
auto convertedBlock =
3593+
getConvertedBlock(rewriter, select, dest,
3594+
mlir::TypeRange(destinationsOperands.back()));
3595+
if (mlir::failed(convertedBlock))
3596+
return mlir::failure();
3597+
destinations.push_back(*convertedBlock);
3598+
caseValues.push_back(intAttr.getInt());
3599+
continue;
3600+
}
3601+
assert(mlir::dyn_cast_or_null<mlir::UnitAttr>(attr));
3602+
assert((t + 1 == conds) && "unit must be last");
3603+
defaultOperands = destOps ? *destOps : mlir::ValueRange{};
3604+
auto convertedBlock = getConvertedBlock(rewriter, select, dest,
3605+
mlir::TypeRange(defaultOperands));
35833606
if (mlir::failed(convertedBlock))
35843607
return mlir::failure();
3585-
destinations.push_back(*convertedBlock);
3586-
caseValues.push_back(intAttr.getInt());
3587-
continue;
3608+
defaultDestination = *convertedBlock;
35883609
}
3589-
assert(mlir::dyn_cast_or_null<mlir::UnitAttr>(attr));
3590-
assert((t + 1 == conds) && "unit must be last");
3591-
defaultOperands = destOps ? *destOps : mlir::ValueRange{};
3592-
auto convertedBlock = getConvertedBlock(rewriter, converter, select, dest,
3593-
mlir::TypeRange(defaultOperands));
3594-
if (mlir::failed(convertedBlock))
3595-
return mlir::failure();
3596-
defaultDestination = *convertedBlock;
3597-
}
3598-
3599-
// LLVM::SwitchOp takes a i32 type for the selector.
3600-
if (select.getSelector().getType() != rewriter.getI32Type())
3601-
selector = mlir::LLVM::TruncOp::create(rewriter, loc, rewriter.getI32Type(),
3602-
selector);
3603-
3604-
rewriter.replaceOpWithNewOp<mlir::LLVM::SwitchOp>(
3605-
select, selector,
3606-
/*defaultDestination=*/defaultDestination,
3607-
/*defaultOperands=*/defaultOperands,
3608-
/*caseValues=*/caseValues,
3609-
/*caseDestinations=*/destinations,
3610-
/*caseOperands=*/destinationsOperands,
3611-
/*branchWeights=*/llvm::ArrayRef<std::int32_t>());
3612-
return mlir::success();
3613-
}
36143610

3611+
selector =
3612+
this->integerCast(loc, rewriter, rewriter.getI64Type(), selector);
3613+
3614+
rewriter.replaceOpWithNewOp<mlir::LLVM::SwitchOp>(
3615+
select, selector,
3616+
/*defaultDestination=*/defaultDestination,
3617+
/*defaultOperands=*/defaultOperands,
3618+
/*caseValues=*/rewriter.getI64VectorAttr(caseValues),
3619+
/*caseDestinations=*/destinations,
3620+
/*caseOperands=*/destinationsOperands,
3621+
/*branchWeights=*/llvm::ArrayRef<std::int32_t>());
3622+
return mlir::success();
3623+
}
3624+
};
36153625
/// conversion of fir::SelectOp to an if-then-else ladder
3616-
struct SelectOpConversion : public fir::FIROpConversion<fir::SelectOp> {
3617-
using FIROpConversion::FIROpConversion;
3626+
struct SelectOpConversion : public SelectOpConversionBase<fir::SelectOp> {
3627+
using SelectOpConversionBase::SelectOpConversionBase;
36183628

36193629
llvm::LogicalResult
36203630
matchAndRewrite(fir::SelectOp op, OpAdaptor adaptor,
36213631
mlir::ConversionPatternRewriter &rewriter) const override {
3622-
return selectMatchAndRewrite<fir::SelectOp>(lowerTy(), op, adaptor,
3623-
rewriter, getTypeConverter());
3632+
return this->selectMatchAndRewrite(op, adaptor, rewriter);
36243633
}
36253634
};
36263635

36273636
/// conversion of fir::SelectRankOp to an if-then-else ladder
3628-
struct SelectRankOpConversion : public fir::FIROpConversion<fir::SelectRankOp> {
3629-
using FIROpConversion::FIROpConversion;
3637+
struct SelectRankOpConversion
3638+
: public SelectOpConversionBase<fir::SelectRankOp> {
3639+
using SelectOpConversionBase::SelectOpConversionBase;
36303640

36313641
llvm::LogicalResult
36323642
matchAndRewrite(fir::SelectRankOp op, OpAdaptor adaptor,
36333643
mlir::ConversionPatternRewriter &rewriter) const override {
3634-
return selectMatchAndRewrite<fir::SelectRankOp>(
3635-
lowerTy(), op, adaptor, rewriter, getTypeConverter());
3644+
return this->selectMatchAndRewrite(op, adaptor, rewriter);
36363645
}
36373646
};
36383647

flang/test/Fir/convert-to-llvm.fir

Lines changed: 54 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -338,8 +338,7 @@ func.func @select(%arg : index, %arg2 : i32) -> i32 {
338338
// CHECK: %[[C0:.*]] = llvm.mlir.constant(1 : i32) : i32
339339
// CHECK: %[[C1:.*]] = llvm.mlir.constant(2 : i32) : i32
340340
// CHECK: %[[C2:.*]] = llvm.mlir.constant(3 : i32) : i32
341-
// CHECK: %[[SELECTOR:.*]] = llvm.trunc %[[SELECTVALUE]] : i{{.*}} to i32
342-
// CHECK: llvm.switch %[[SELECTOR]] : i32, ^bb5 [
341+
// CHECK: llvm.switch %[[SELECTVALUE]] : i64, ^bb5 [
343342
// CHECK: 1: ^bb1(%[[C0]] : i32),
344343
// CHECK: 2: ^bb2(%[[C2]], %[[SELECTVALUE]], %[[ARG1]] : i32, [[IDX]], i32),
345344
// CHECK: 3: ^bb3(%[[ARG1]], %[[C2]] : i32, i32),
@@ -384,7 +383,8 @@ func.func @select_rank(%arg : i32, %arg2 : i32) -> i32 {
384383
// CHECK: %[[C0:.*]] = llvm.mlir.constant(1 : i32) : i32
385384
// CHECK: %[[C1:.*]] = llvm.mlir.constant(2 : i32) : i32
386385
// CHECK: %[[C2:.*]] = llvm.mlir.constant(3 : i32) : i32
387-
// CHECK: llvm.switch %[[SELECTVALUE]] : i32, ^bb5 [
386+
// CHECK: %[[SELECTOR:.*]] = llvm.sext %[[SELECTVALUE]] : i{{.*}} to i64
387+
// CHECK: llvm.switch %[[SELECTOR]] : i64, ^bb5 [
388388
// CHECK: 1: ^bb1(%[[C0]] : i32),
389389
// CHECK: 2: ^bb2(%[[C2]], %[[SELECTVALUE]], %[[ARG1]] : i32, i32, i32),
390390
// CHECK: 3: ^bb3(%[[ARG1]], %[[C2]] : i32, i32),
@@ -2853,23 +2853,74 @@ func.func @test_call_arg_attrs_direct(%arg0: i32, %arg1: !fir.ref<i64>) {
28532853
return
28542854
}
28552855

2856+
// -----
2857+
28562858
// CHECK-LABEL: @test_call_arg_attrs_indirect
28572859
func.func @test_call_arg_attrs_indirect(%arg0: i16, %arg1: (i16)-> i16) -> i16 {
28582860
// CHECK: llvm.call %arg1(%{{.*}}) : !llvm.ptr, (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext})
28592861
%0 = fir.call %arg1(%arg0) : (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext})
28602862
return %0 : i16
28612863
}
28622864

2865+
// -----
2866+
28632867
// CHECK-LABEL: @test_byval
28642868
func.func @test_byval(%arg0: (!fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, f64) -> (), %arg1: !fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, %arg2: f64) {
28652869
// llvm.call %{{.*}}(%{{.*}}, %{{.*}}) : !llvm.ptr, (!llvm.ptr {llvm.byval = !llvm.struct<"t", (array<5 x f64>)>}, f64) -> ()
28662870
fir.call %arg0(%arg1, %arg2) : (!fir.ref<!fir.type<t{a:!fir.array<5xf64>}>> {llvm.byval = !fir.type<t{a:!fir.array<5xf64>}>}, f64) -> ()
28672871
return
28682872
}
28692873

2874+
// -----
2875+
28702876
// CHECK-LABEL: @test_sret
28712877
func.func @test_sret(%arg0: (!fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, f64) -> (), %arg1: !fir.ref<!fir.type<t{a:!fir.array<5xf64>}>>, %arg2: f64) {
28722878
// llvm.call %{{.*}}(%{{.*}}, %{{.*}}) : !llvm.ptr, (!llvm.ptr {llvm.sret = !llvm.struct<"t", (array<5 x f64>)>}, f64) -> ()
28732879
fir.call %arg0(%arg1, %arg2) : (!fir.ref<!fir.type<t{a:!fir.array<5xf64>}>> {llvm.sret = !fir.type<t{a:!fir.array<5xf64>}>}, f64) -> ()
28742880
return
28752881
}
2882+
2883+
// -----
2884+
2885+
func.func @select_with_cast(%arg1 : i8, %arg2 : i16, %arg3: i64, %arg4: index) -> () {
2886+
fir.select %arg1 : i8 [ 1, ^bb1, unit, ^bb1 ]
2887+
^bb1:
2888+
fir.select %arg2 : i16 [ 1, ^bb2, unit, ^bb2 ]
2889+
^bb2:
2890+
fir.select %arg3 : i64 [ 1, ^bb3, unit, ^bb3 ]
2891+
^bb3:
2892+
fir.select %arg4 : index [ 1, ^bb4, unit, ^bb4 ]
2893+
^bb4:
2894+
fir.select %arg3 : i64 [ 4294967296, ^bb5, unit, ^bb5 ]
2895+
^bb5:
2896+
return
2897+
}
2898+
// CHECK-LABEL: llvm.func @select_with_cast(
2899+
// CHECK-SAME: %[[ARG0:.*]]: i8,
2900+
// CHECK-SAME: %[[ARG1:.*]]: i16,
2901+
// CHECK-SAME: %[[ARG2:.*]]: i64,
2902+
// CHECK-SAME: %[[ARG3:.*]]: i64) {
2903+
// CHECK: %[[VAL_0:.*]] = llvm.sext %[[ARG0]] : i8 to i64
2904+
// CHECK: llvm.switch %[[VAL_0]] : i64, ^bb1 [
2905+
// CHECK: 1: ^bb1
2906+
// CHECK: ]
2907+
// CHECK: ^bb1:
2908+
// CHECK: %[[VAL_1:.*]] = llvm.sext %[[ARG1]] : i16 to i64
2909+
// CHECK: llvm.switch %[[VAL_1]] : i64, ^bb2 [
2910+
// CHECK: 1: ^bb2
2911+
// CHECK: ]
2912+
// CHECK: ^bb2:
2913+
// CHECK: llvm.switch %[[ARG2]] : i64, ^bb3 [
2914+
// CHECK: 1: ^bb3
2915+
// CHECK: ]
2916+
// CHECK: ^bb3:
2917+
// CHECK: llvm.switch %[[ARG3]] : i64, ^bb4 [
2918+
// CHECK: 1: ^bb4
2919+
// CHECK: ]
2920+
// CHECK: ^bb4:
2921+
// CHECK: llvm.switch %[[ARG2]] : i64, ^bb5 [
2922+
// CHECK: 4294967296: ^bb5
2923+
// CHECK: ]
2924+
// CHECK: ^bb5:
2925+
// CHECK: llvm.return
2926+
// CHECK: }

flang/test/Fir/select.fir

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
func.func @f(%a : i32) -> i32 {
88
%1 = arith.constant 1 : i32
99
%2 = arith.constant 42 : i32
10-
// CHECK: switch i32 %{{.*}}, label %{{.*}} [
11-
// CHECK: i32 1, label %{{.*}}
10+
// CHECK: switch i64 %{{.*}}, label %{{.*}} [
11+
// CHECK: i64 1, label %{{.*}}
1212
// CHECK: ]
1313
fir.select %a : i32 [1, ^bb2(%1:i32), unit, ^bb3(%2:i32)]
1414
^bb2(%3 : i32) :
@@ -24,9 +24,9 @@ func.func @g(%a : i32) -> i32 {
2424
%1 = arith.constant 1 : i32
2525
%2 = arith.constant 42 : i32
2626

27-
// CHECK: switch i32 %{{.*}}, label %{{.*}} [
28-
// CHECK: i32 1, label %{{.*}}
29-
// CHECK: i32 -1, label %{{.*}}
27+
// CHECK: switch i64 %{{.*}}, label %{{.*}} [
28+
// CHECK: i64 1, label %{{.*}}
29+
// CHECK: i64 -1, label %{{.*}}
3030
// CHECK: ]
3131
fir.select_rank %a : i32 [1, ^bb2(%1:i32), -1, ^bb4, unit, ^bb3(%2:i32)]
3232
^bb2(%3 : i32) :

0 commit comments

Comments
 (0)