Skip to content

Commit dc2267c

Browse files
klauslerjeanPerier
authored andcommitted
[flang] Fix bogus folding error for ISHFT(x, negative)
Negative shift counts are of course valid for ISHFT when shifting to the right. This patch decouples the folding of ISHFT from that of SHIFTA/L/R and adds tests. Differential Revision: https://reviews.llvm.org/D112244
1 parent e18d718 commit dc2267c

File tree

3 files changed

+70
-21
lines changed

3 files changed

+70
-21
lines changed

flang/lib/Evaluate/fold-integer.cpp

Lines changed: 33 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -610,34 +610,21 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
610610
} else if (name == "iparity") {
611611
return FoldBitReduction(
612612
context, std::move(funcRef), &Scalar<T>::IEOR, Scalar<T>{});
613-
} else if (name == "ishft" || name == "shifta" || name == "shiftr" ||
614-
name == "shiftl") {
615-
// Second argument can be of any kind. However, it must be smaller or
616-
// equal than BIT_SIZE. It can be converted to Int4 to simplify.
617-
auto fptr{&Scalar<T>::ISHFT};
618-
if (name == "ishft") { // done in fptr definition
619-
} else if (name == "shifta") {
620-
fptr = &Scalar<T>::SHIFTA;
621-
} else if (name == "shiftr") {
622-
fptr = &Scalar<T>::SHIFTR;
623-
} else if (name == "shiftl") {
624-
fptr = &Scalar<T>::SHIFTL;
625-
} else {
626-
common::die("missing case to fold intrinsic function %s", name.c_str());
627-
}
613+
} else if (name == "ishft") {
628614
return FoldElementalIntrinsic<T, T, Int4>(context, std::move(funcRef),
629615
ScalarFunc<T, T, Int4>([&](const Scalar<T> &i,
630616
const Scalar<Int4> &pos) -> Scalar<T> {
631617
auto posVal{static_cast<int>(pos.ToInt64())};
632-
if (posVal < 0) {
618+
if (posVal < -i.bits) {
633619
context.messages().Say(
634-
"shift count for %s (%d) is negative"_err_en_US, name, posVal);
620+
"SHIFT=%d count for ishft is less than %d"_err_en_US, posVal,
621+
-i.bits);
635622
} else if (posVal > i.bits) {
636623
context.messages().Say(
637-
"shift count for %s (%d) is greater than %d"_err_en_US, name,
638-
posVal, i.bits);
624+
"SHIFT=%d count for ishft is greater than %d"_err_en_US, posVal,
625+
i.bits);
639626
}
640-
return std::invoke(fptr, i, posVal);
627+
return i.ISHFT(posVal);
641628
}));
642629
} else if (name == "lbound") {
643630
return LBOUND(context, std::move(funcRef));
@@ -856,6 +843,32 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
856843
return Fold(context, ConvertToType<T>(std::move(*shapeExpr)));
857844
}
858845
}
846+
} else if (name == "shifta" || name == "shiftr" || name == "shiftl") {
847+
// Second argument can be of any kind. However, it must be smaller or
848+
// equal than BIT_SIZE. It can be converted to Int4 to simplify.
849+
auto fptr{&Scalar<T>::SHIFTA};
850+
if (name == "shifta") { // done in fptr definition
851+
} else if (name == "shiftr") {
852+
fptr = &Scalar<T>::SHIFTR;
853+
} else if (name == "shiftl") {
854+
fptr = &Scalar<T>::SHIFTL;
855+
} else {
856+
common::die("missing case to fold intrinsic function %s", name.c_str());
857+
}
858+
return FoldElementalIntrinsic<T, T, Int4>(context, std::move(funcRef),
859+
ScalarFunc<T, T, Int4>([&](const Scalar<T> &i,
860+
const Scalar<Int4> &pos) -> Scalar<T> {
861+
auto posVal{static_cast<int>(pos.ToInt64())};
862+
if (posVal < 0) {
863+
context.messages().Say(
864+
"SHIFT=%d count for %s is negative"_err_en_US, posVal, name);
865+
} else if (posVal > i.bits) {
866+
context.messages().Say(
867+
"SHIFT=%d count for %s is greater than %d"_err_en_US, posVal,
868+
name, i.bits);
869+
}
870+
return std::invoke(fptr, i, posVal);
871+
}));
859872
} else if (name == "sign") {
860873
return FoldElementalIntrinsic<T, T, T>(context, std::move(funcRef),
861874
ScalarFunc<T, T, T>(

flang/test/Evaluate/fold-ishft.f90

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
! RUN: %python %S/test_folding.py %s %flang_fc1
2+
! Tests folding of ISHFT
3+
module m1
4+
logical :: test_ishft_lsb = all(ishft(1, [-32, -31, -1, 0, 1, 2, 31, 32]) == [0, 0, 0, 1, 2, 4, int(z'80000000'), 0])
5+
logical :: test_ishft_msb = all(ishft(ishft(1,31), [-32, -31, -1, 0, 1, 2, 31, 32]) == [0, 1, int(z'40000000'), int(z'80000000'), 0, 0, 0, 0])
6+
end module

flang/test/Evaluate/folding19.f90

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,38 @@ subroutine s6
5656
!CHECK: error: POS=32 out of range for BTEST
5757
logical, parameter :: bad2 = btest(0, 32)
5858
!CHECK-NOT: error: POS=33 out of range for BTEST
59-
logical, parameter :: bad3 = btest(0_8, 33)
59+
logical, parameter :: ok1 = btest(0_8, 33)
6060
!CHECK: error: POS=64 out of range for BTEST
6161
logical, parameter :: bad4 = btest(0_8, 64)
6262
end subroutine
63+
subroutine s7
64+
!CHECK: error: SHIFT=-33 count for ishft is less than -32
65+
integer, parameter :: bad1 = ishft(1, -33)
66+
integer, parameter :: ok1 = ishft(1, -32)
67+
integer, parameter :: ok2 = ishft(1, 32)
68+
!CHECK: error: SHIFT=33 count for ishft is greater than 32
69+
integer, parameter :: bad2 = ishft(1, 33)
70+
!CHECK: error: SHIFT=-65 count for ishft is less than -64
71+
integer(8), parameter :: bad3 = ishft(1_8, -65)
72+
integer(8), parameter :: ok3 = ishft(1_8, -64)
73+
integer(8), parameter :: ok4 = ishft(1_8, 64)
74+
!CHECK: error: SHIFT=65 count for ishft is greater than 64
75+
integer(8), parameter :: bad4 = ishft(1_8, 65)
76+
end subroutine
77+
subroutine s8
78+
!CHECK: error: SHIFT=-33 count for shiftl is negative
79+
integer, parameter :: bad1 = shiftl(1, -33)
80+
!CHECK: error: SHIFT=-32 count for shiftl is negative
81+
integer, parameter :: bad2 = shiftl(1, -32)
82+
integer, parameter :: ok1 = shiftl(1, 32)
83+
!CHECK: error: SHIFT=33 count for shiftl is greater than 32
84+
integer, parameter :: bad3 = shiftl(1, 33)
85+
!CHECK: error: SHIFT=-65 count for shiftl is negative
86+
integer(8), parameter :: bad4 = shiftl(1_8, -65)
87+
!CHECK: error: SHIFT=-64 count for shiftl is negative
88+
integer(8), parameter :: bad5 = shiftl(1_8, -64)
89+
integer(8), parameter :: ok2 = shiftl(1_8, 64)
90+
!CHECK: error: SHIFT=65 count for shiftl is greater than 64
91+
integer(8), parameter :: bad6 = shiftl(1_8, 65)
92+
end subroutine
6393
end module

0 commit comments

Comments
 (0)