Skip to content

Commit cfa4946

Browse files
klauslerjeanPerier
authored andcommitted
[flang] Fold PACK()
Implement compile-time constant folding for the transformational intrinsic function PACK. Differential Revision: https://reviews.llvm.org/D108956
1 parent 67262fd commit cfa4946

File tree

3 files changed

+125
-6
lines changed

3 files changed

+125
-6
lines changed

flang/lib/Evaluate/fold-implementation.h

Lines changed: 84 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ template <typename T> class Folder {
6363

6464
Expr<T> CSHIFT(FunctionRef<T> &&);
6565
Expr<T> EOSHIFT(FunctionRef<T> &&);
66+
Expr<T> PACK(FunctionRef<T> &&);
6667
Expr<T> RESHAPE(FunctionRef<T> &&);
6768

6869
private:
@@ -580,7 +581,7 @@ template <typename T> Expr<T> Folder<T>::CSHIFT(FunctionRef<T> &&funcRef) {
580581
if (j != zbDim) {
581582
if (array->shape()[j] != shift->shape()[k]) {
582583
context_.messages().Say(
583-
"Invalid 'shift=' argument in CSHIFT; extent on dimension %d is %jd but must be %jd"_err_en_US,
584+
"Invalid 'shift=' argument in CSHIFT: extent on dimension %d is %jd but must be %jd"_err_en_US,
584585
k + 1, static_cast<std::intmax_t>(shift->shape()[k]),
585586
static_cast<std::intmax_t>(array->shape()[j]));
586587
ok = false;
@@ -653,6 +654,9 @@ template <typename T> Expr<T> Folder<T>::EOSHIFT(FunctionRef<T> &&funcRef) {
653654
static_cast<std::intmax_t>(*dim));
654655
} else if (shift->Rank() > 0 && shift->Rank() != array->Rank() - 1) {
655656
// message already emitted from intrinsic look-up
657+
} else if (boundary && boundary->Rank() > 0 &&
658+
boundary->Rank() != array->Rank() - 1) {
659+
// ditto
656660
} else {
657661
int rank{array->Rank()};
658662
int zbDim{static_cast<int>(*dim) - 1};
@@ -663,15 +667,23 @@ template <typename T> Expr<T> Folder<T>::EOSHIFT(FunctionRef<T> &&funcRef) {
663667
if (j != zbDim) {
664668
if (array->shape()[j] != shift->shape()[k]) {
665669
context_.messages().Say(
666-
"Invalid 'shift=' argument in EOSHIFT; extent on dimension %d is %jd but must be %jd"_err_en_US,
670+
"Invalid 'shift=' argument in EOSHIFT: extent on dimension %d is %jd but must be %jd"_err_en_US,
667671
k + 1, static_cast<std::intmax_t>(shift->shape()[k]),
668672
static_cast<std::intmax_t>(array->shape()[j]));
669673
ok = false;
670674
}
671-
if (boundary && array->shape()[j] != boundary->shape()[k]) {
675+
++k;
676+
}
677+
}
678+
}
679+
if (boundary && boundary->Rank() > 0) {
680+
int k{0};
681+
for (int j{0}; j < rank; ++j) {
682+
if (j != zbDim) {
683+
if (array->shape()[j] != boundary->shape()[k]) {
672684
context_.messages().Say(
673-
"Invalid 'boundary=' argument in EOSHIFT; extent on dimension %d is %jd but must be %jd"_err_en_US,
674-
k + 1, static_cast<std::intmax_t>(shift->shape()[k]),
685+
"Invalid 'boundary=' argument in EOSHIFT: extent on dimension %d is %jd but must be %jd"_err_en_US,
686+
k + 1, static_cast<std::intmax_t>(boundary->shape()[k]),
675687
static_cast<std::intmax_t>(array->shape()[j]));
676688
ok = false;
677689
}
@@ -726,6 +738,70 @@ template <typename T> Expr<T> Folder<T>::EOSHIFT(FunctionRef<T> &&funcRef) {
726738
return MakeInvalidIntrinsic(std::move(funcRef));
727739
}
728740

741+
template <typename T> Expr<T> Folder<T>::PACK(FunctionRef<T> &&funcRef) {
742+
auto args{funcRef.arguments()};
743+
CHECK(args.size() == 3);
744+
const auto *array{UnwrapConstantValue<T>(args[0])};
745+
const auto *vector{UnwrapConstantValue<T>(args[2])};
746+
auto convertedMask{Fold(context_,
747+
ConvertToType<LogicalResult>(
748+
Expr<SomeLogical>{DEREF(UnwrapExpr<Expr<SomeLogical>>(args[1]))}))};
749+
const auto *mask{UnwrapConstantValue<LogicalResult>(convertedMask)};
750+
if (!array || !mask || (args[2] && !vector)) {
751+
return Expr<T>{std::move(funcRef)};
752+
}
753+
// Arguments are constant.
754+
ConstantSubscript arrayElements{GetSize(array->shape())};
755+
ConstantSubscript truths{0};
756+
ConstantSubscripts maskAt{mask->lbounds()};
757+
if (mask->Rank() == 0) {
758+
if (mask->At(maskAt).IsTrue()) {
759+
truths = arrayElements;
760+
}
761+
} else if (array->shape() != mask->shape()) {
762+
// Error already emitted from intrinsic processing
763+
return MakeInvalidIntrinsic(std::move(funcRef));
764+
} else {
765+
for (ConstantSubscript j{0}; j < arrayElements;
766+
++j, mask->IncrementSubscripts(maskAt)) {
767+
if (mask->At(maskAt).IsTrue()) {
768+
++truths;
769+
}
770+
}
771+
}
772+
std::vector<Scalar<T>> resultElements;
773+
ConstantSubscripts arrayAt{array->lbounds()};
774+
ConstantSubscript resultSize{truths};
775+
if (vector) {
776+
resultSize = vector->shape().at(0);
777+
if (resultSize < truths) {
778+
context_.messages().Say(
779+
"Invalid 'vector=' argument in PACK: the 'mask=' argument has %jd true elements, but the vector has only %jd elements"_err_en_US,
780+
static_cast<std::intmax_t>(truths),
781+
static_cast<std::intmax_t>(resultSize));
782+
return MakeInvalidIntrinsic(std::move(funcRef));
783+
}
784+
}
785+
for (ConstantSubscript j{0}; j < truths;) {
786+
if (mask->At(maskAt).IsTrue()) {
787+
resultElements.push_back(array->At(arrayAt));
788+
++j;
789+
}
790+
array->IncrementSubscripts(arrayAt);
791+
mask->IncrementSubscripts(maskAt);
792+
}
793+
if (vector) {
794+
ConstantSubscripts vectorAt{vector->lbounds()};
795+
vectorAt.at(0) += truths;
796+
for (ConstantSubscript j{truths}; j < resultSize; ++j) {
797+
resultElements.push_back(vector->At(vectorAt));
798+
++vectorAt[0];
799+
}
800+
}
801+
return Expr<T>{PackageConstant<T>(std::move(resultElements), *array,
802+
ConstantSubscripts{static_cast<ConstantSubscript>(resultSize)})};
803+
}
804+
729805
template <typename T> Expr<T> Folder<T>::RESHAPE(FunctionRef<T> &&funcRef) {
730806
auto args{funcRef.arguments()};
731807
CHECK(args.size() == 4);
@@ -863,10 +939,12 @@ Expr<T> FoldOperation(FoldingContext &context, FunctionRef<T> &&funcRef) {
863939
return Folder<T>{context}.CSHIFT(std::move(funcRef));
864940
} else if (name == "eoshift") {
865941
return Folder<T>{context}.EOSHIFT(std::move(funcRef));
942+
} else if (name == "pack") {
943+
return Folder<T>{context}.PACK(std::move(funcRef));
866944
} else if (name == "reshape") {
867945
return Folder<T>{context}.RESHAPE(std::move(funcRef));
868946
}
869-
// TODO: eoshift, pack, spread, unpack, transpose
947+
// TODO: spread, unpack, transpose
870948
// TODO: extends_type_of, same_type_as
871949
if constexpr (!std::is_same_v<T, SomeDerived>) {
872950
return FoldIntrinsicFunction(context, std::move(funcRef));

flang/test/Evaluate/folding19.f90

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,5 +18,30 @@ subroutine s1(a,b)
1818
!CHECK: error: DIM=2 dimension is out of range for rank-1 array
1919
integer :: lb3(lbound(b,2))
2020
end subroutine
21+
subroutine s2
22+
integer, parameter :: array(2,3) = reshape([(j, j=1, 6)], shape(array))
23+
integer :: x(2, 3)
24+
!CHECK: error: Invalid 'dim=' argument (0) in CSHIFT
25+
x = cshift(array, [1, 2], dim=0)
26+
!CHECK: error: Invalid 'shift=' argument in CSHIFT: extent on dimension 1 is 2 but must be 3
27+
x = cshift(array, [1, 2], dim=1)
28+
end subroutine
29+
subroutine s3
30+
integer, parameter :: array(2,3) = reshape([(j, j=1, 6)], shape(array))
31+
integer :: x(2, 3)
32+
!CHECK: error: Invalid 'dim=' argument (0) in EOSHIFT
33+
x = eoshift(array, [1, 2], dim=0)
34+
!CHECK: error: Invalid 'shift=' argument in EOSHIFT: extent on dimension 1 is 2 but must be 3
35+
x = eoshift(array, [1, 2], dim=1)
36+
!CHECK: error: Invalid 'boundary=' argument in EOSHIFT: extent on dimension 1 is 3 but must be 2
37+
x = eoshift(array, 1, [0, 0, 0], 2)
38+
end subroutine
39+
subroutine s4
40+
integer, parameter :: array(2,3) = reshape([(j, j=1, 6)], shape(array))
41+
logical, parameter :: mask(*,*) = reshape([(.true., j=1,3),(.false., j=1,3)], shape(array))
42+
integer :: x(3)
43+
!CHECK: error: Invalid 'vector=' argument in PACK: the 'mask=' argument has 3 true elements, but the vector has only 2 elements
44+
x = pack(array, mask, [0,0])
45+
end subroutine
2146
end module
2247

flang/test/Evaluate/folding24.f90

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
! RUN: %S/test_folding.sh %s %t %flang_fc1
2+
! REQUIRES: shell
3+
! Tests folding of PACK (valid cases)
4+
module m
5+
integer, parameter :: arr(2,3) = reshape([1, 2, 3, 4, 5, 6], shape(arr))
6+
logical, parameter :: odds(*,*) = modulo(arr, 2) /= 0
7+
integer, parameter :: vect(*) = [(j, j=-10, -1)]
8+
logical, parameter :: test_pack_1 = all(pack(arr, .true.) == [arr])
9+
logical, parameter :: test_pack_2 = all(pack(arr, .false.) == [integer::])
10+
logical, parameter :: test_pack_3 = all(pack(arr, odds) == [1, 3, 5])
11+
logical, parameter :: test_pack_4 = all(pack(arr, .not. odds) == [2, 4, 6])
12+
logical, parameter :: test_pack_5 = all(pack(arr, .true., vect) == [1, 2, 3, 4, 5, 6, -4, -3, -2, -1])
13+
logical, parameter :: test_pack_6 = all(pack(arr, .false., vect) == vect)
14+
logical, parameter :: test_pack_7 = all(pack(arr, odds, vect) == [1, 3, 5, -7, -6, -5, -4, -3, -2, -1])
15+
logical, parameter :: test_pack_8 = all(pack(arr, .not. odds, vect) == [2, 4, 6, -7, -6, -5, -4, -3, -2, -1])
16+
end module

0 commit comments

Comments
 (0)