Skip to content

Commit 4576068

Browse files
klauslerjeanPerier
authored andcommitted
[flang] Fold UNPACK and TRANSPOSE
Implement constant folding for the transformational intrinsic functions UNPACK and TRANSPOSE. Differential Revision: https://reviews.llvm.org/D109010
1 parent a44240c commit 4576068

File tree

4 files changed

+103
-2
lines changed

4 files changed

+103
-2
lines changed

flang/lib/Evaluate/fold-implementation.h

Lines changed: 79 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ template <typename T> class Folder {
6565
Expr<T> EOSHIFT(FunctionRef<T> &&);
6666
Expr<T> PACK(FunctionRef<T> &&);
6767
Expr<T> RESHAPE(FunctionRef<T> &&);
68+
Expr<T> TRANSPOSE(FunctionRef<T> &&);
69+
Expr<T> UNPACK(FunctionRef<T> &&);
6870

6971
private:
7072
FoldingContext &context_;
@@ -853,6 +855,78 @@ template <typename T> Expr<T> Folder<T>::RESHAPE(FunctionRef<T> &&funcRef) {
853855
return MakeInvalidIntrinsic(std::move(funcRef));
854856
}
855857

858+
template <typename T> Expr<T> Folder<T>::TRANSPOSE(FunctionRef<T> &&funcRef) {
859+
auto args{funcRef.arguments()};
860+
CHECK(args.size() == 1);
861+
const auto *matrix{UnwrapConstantValue<T>(args[0])};
862+
if (!matrix) {
863+
return Expr<T>{std::move(funcRef)};
864+
}
865+
// Argument is constant. Traverse its elements in transposed order.
866+
std::vector<Scalar<T>> resultElements;
867+
ConstantSubscripts at(2);
868+
for (ConstantSubscript j{0}; j < matrix->shape()[0]; ++j) {
869+
at[0] = matrix->lbounds()[0] + j;
870+
for (ConstantSubscript k{0}; k < matrix->shape()[1]; ++k) {
871+
at[1] = matrix->lbounds()[1] + k;
872+
resultElements.push_back(matrix->At(at));
873+
}
874+
}
875+
at = matrix->shape();
876+
std::swap(at[0], at[1]);
877+
return Expr<T>{PackageConstant<T>(std::move(resultElements), *matrix, at)};
878+
}
879+
880+
template <typename T> Expr<T> Folder<T>::UNPACK(FunctionRef<T> &&funcRef) {
881+
auto args{funcRef.arguments()};
882+
CHECK(args.size() == 3);
883+
const auto *vector{UnwrapConstantValue<T>(args[0])};
884+
auto convertedMask{Fold(context_,
885+
ConvertToType<LogicalResult>(
886+
Expr<SomeLogical>{DEREF(UnwrapExpr<Expr<SomeLogical>>(args[1]))}))};
887+
const auto *mask{UnwrapConstantValue<LogicalResult>(convertedMask)};
888+
const auto *field{UnwrapConstantValue<T>(args[2])};
889+
if (!vector || !mask || !field) {
890+
return Expr<T>{std::move(funcRef)};
891+
}
892+
// Arguments are constant.
893+
if (field->Rank() > 0 && field->shape() != mask->shape()) {
894+
// Error already emitted from intrinsic processing
895+
return MakeInvalidIntrinsic(std::move(funcRef));
896+
}
897+
ConstantSubscript maskElements{GetSize(mask->shape())};
898+
ConstantSubscript truths{0};
899+
ConstantSubscripts maskAt{mask->lbounds()};
900+
for (ConstantSubscript j{0}; j < maskElements;
901+
++j, mask->IncrementSubscripts(maskAt)) {
902+
if (mask->At(maskAt).IsTrue()) {
903+
++truths;
904+
}
905+
}
906+
if (truths > GetSize(vector->shape())) {
907+
context_.messages().Say(
908+
"Invalid 'vector=' argument in UNPACK: the 'mask=' argument has %jd true elements, but the vector has only %jd elements"_err_en_US,
909+
static_cast<std::intmax_t>(truths),
910+
static_cast<std::intmax_t>(GetSize(vector->shape())));
911+
return MakeInvalidIntrinsic(std::move(funcRef));
912+
}
913+
std::vector<Scalar<T>> resultElements;
914+
ConstantSubscripts vectorAt{vector->lbounds()};
915+
ConstantSubscripts fieldAt{field->lbounds()};
916+
for (ConstantSubscript j{0}; j < maskElements; ++j) {
917+
if (mask->At(maskAt).IsTrue()) {
918+
resultElements.push_back(vector->At(vectorAt));
919+
vector->IncrementSubscripts(vectorAt);
920+
} else {
921+
resultElements.push_back(field->At(fieldAt));
922+
}
923+
mask->IncrementSubscripts(maskAt);
924+
field->IncrementSubscripts(fieldAt);
925+
}
926+
return Expr<T>{
927+
PackageConstant<T>(std::move(resultElements), *vector, mask->shape())};
928+
}
929+
856930
template <typename T>
857931
Expr<T> FoldMINorMAX(
858932
FoldingContext &context, FunctionRef<T> &&funcRef, Ordering order) {
@@ -943,8 +1017,12 @@ Expr<T> FoldOperation(FoldingContext &context, FunctionRef<T> &&funcRef) {
9431017
return Folder<T>{context}.PACK(std::move(funcRef));
9441018
} else if (name == "reshape") {
9451019
return Folder<T>{context}.RESHAPE(std::move(funcRef));
1020+
} else if (name == "transpose") {
1021+
return Folder<T>{context}.TRANSPOSE(std::move(funcRef));
1022+
} else if (name == "unpack") {
1023+
return Folder<T>{context}.UNPACK(std::move(funcRef));
9461024
}
947-
// TODO: spread, unpack, transpose
1025+
// TODO: spread
9481026
// TODO: extends_type_of, same_type_as
9491027
if constexpr (!std::is_same_v<T, SomeDerived>) {
9501028
return FoldIntrinsicFunction(context, std::move(funcRef));

flang/test/Evaluate/folding19.f90

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,5 +43,11 @@ subroutine s4
4343
!CHECK: error: Invalid 'vector=' argument in PACK: the 'mask=' argument has 3 true elements, but the vector has only 2 elements
4444
x = pack(array, mask, [0,0])
4545
end subroutine
46+
subroutine s5
47+
logical, parameter :: mask(2,3) = reshape([.false., .true., .true., .false., .false., .true.], shape(mask))
48+
integer, parameter :: field(3,2) = reshape([(-j,j=1,6)], shape(field))
49+
integer :: x(2,3)
50+
!CHECK: error: Invalid 'vector=' argument in UNPACK: the 'mask=' argument has 3 true elements, but the vector has only 2 elements
51+
x = unpack([1,2], mask, 0)
52+
end subroutine
4653
end module
47-

flang/test/Evaluate/folding25.f90

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
! RUN: %S/test_folding.sh %s %t %flang_fc1
2+
! REQUIRES: shell
3+
! Tests folding of UNPACK (valid cases)
4+
module m
5+
integer, parameter :: vector(*) = [1, 2, 3, 4]
6+
integer, parameter :: field(2,3) = reshape([(-j,j=1,6)], shape(field))
7+
logical, parameter :: mask(*,*) = reshape([.false., .true., .true., .false., .false., .true.], shape(field))
8+
logical, parameter :: test_unpack_1 = all(unpack(vector, mask, 0) == reshape([0,1,2,0,0,3], shape(mask)))
9+
logical, parameter :: test_unpack_2 = all(unpack(vector, mask, field) == reshape([-1,1,2,-4,-5,3], shape(mask)))
10+
end module

flang/test/Evaluate/folding26.f90

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
! RUN: %S/test_folding.sh %s %t %flang_fc1
2+
! REQUIRES: shell
3+
! Tests folding of TRANSPOSE
4+
module m
5+
integer, parameter :: matrix(0:1,0:2) = reshape([1,2,3,4,5,6],shape(matrix))
6+
logical, parameter :: test_transpose_1 = all(transpose(matrix) == reshape([1,3,5,2,4,6],[3,2]))
7+
end module

0 commit comments

Comments
 (0)