@@ -65,6 +65,8 @@ template <typename T> class Folder {
65
65
Expr<T> EOSHIFT (FunctionRef<T> &&);
66
66
Expr<T> PACK (FunctionRef<T> &&);
67
67
Expr<T> RESHAPE (FunctionRef<T> &&);
68
+ Expr<T> TRANSPOSE (FunctionRef<T> &&);
69
+ Expr<T> UNPACK (FunctionRef<T> &&);
68
70
69
71
private:
70
72
FoldingContext &context_;
@@ -853,6 +855,78 @@ template <typename T> Expr<T> Folder<T>::RESHAPE(FunctionRef<T> &&funcRef) {
853
855
return MakeInvalidIntrinsic (std::move (funcRef));
854
856
}
855
857
858
+ template <typename T> Expr<T> Folder<T>::TRANSPOSE(FunctionRef<T> &&funcRef) {
859
+ auto args{funcRef.arguments ()};
860
+ CHECK (args.size () == 1 );
861
+ const auto *matrix{UnwrapConstantValue<T>(args[0 ])};
862
+ if (!matrix) {
863
+ return Expr<T>{std::move (funcRef)};
864
+ }
865
+ // Argument is constant. Traverse its elements in transposed order.
866
+ std::vector<Scalar<T>> resultElements;
867
+ ConstantSubscripts at (2 );
868
+ for (ConstantSubscript j{0 }; j < matrix->shape ()[0 ]; ++j) {
869
+ at[0 ] = matrix->lbounds ()[0 ] + j;
870
+ for (ConstantSubscript k{0 }; k < matrix->shape ()[1 ]; ++k) {
871
+ at[1 ] = matrix->lbounds ()[1 ] + k;
872
+ resultElements.push_back (matrix->At (at));
873
+ }
874
+ }
875
+ at = matrix->shape ();
876
+ std::swap (at[0 ], at[1 ]);
877
+ return Expr<T>{PackageConstant<T>(std::move (resultElements), *matrix, at)};
878
+ }
879
+
880
+ template <typename T> Expr<T> Folder<T>::UNPACK(FunctionRef<T> &&funcRef) {
881
+ auto args{funcRef.arguments ()};
882
+ CHECK (args.size () == 3 );
883
+ const auto *vector{UnwrapConstantValue<T>(args[0 ])};
884
+ auto convertedMask{Fold (context_,
885
+ ConvertToType<LogicalResult>(
886
+ Expr<SomeLogical>{DEREF (UnwrapExpr<Expr<SomeLogical>>(args[1 ]))}))};
887
+ const auto *mask{UnwrapConstantValue<LogicalResult>(convertedMask)};
888
+ const auto *field{UnwrapConstantValue<T>(args[2 ])};
889
+ if (!vector || !mask || !field) {
890
+ return Expr<T>{std::move (funcRef)};
891
+ }
892
+ // Arguments are constant.
893
+ if (field->Rank () > 0 && field->shape () != mask->shape ()) {
894
+ // Error already emitted from intrinsic processing
895
+ return MakeInvalidIntrinsic (std::move (funcRef));
896
+ }
897
+ ConstantSubscript maskElements{GetSize (mask->shape ())};
898
+ ConstantSubscript truths{0 };
899
+ ConstantSubscripts maskAt{mask->lbounds ()};
900
+ for (ConstantSubscript j{0 }; j < maskElements;
901
+ ++j, mask->IncrementSubscripts (maskAt)) {
902
+ if (mask->At (maskAt).IsTrue ()) {
903
+ ++truths;
904
+ }
905
+ }
906
+ if (truths > GetSize (vector->shape ())) {
907
+ context_.messages ().Say (
908
+ " Invalid 'vector=' argument in UNPACK: the 'mask=' argument has %jd true elements, but the vector has only %jd elements" _err_en_US,
909
+ static_cast <std::intmax_t >(truths),
910
+ static_cast <std::intmax_t >(GetSize (vector->shape ())));
911
+ return MakeInvalidIntrinsic (std::move (funcRef));
912
+ }
913
+ std::vector<Scalar<T>> resultElements;
914
+ ConstantSubscripts vectorAt{vector->lbounds ()};
915
+ ConstantSubscripts fieldAt{field->lbounds ()};
916
+ for (ConstantSubscript j{0 }; j < maskElements; ++j) {
917
+ if (mask->At (maskAt).IsTrue ()) {
918
+ resultElements.push_back (vector->At (vectorAt));
919
+ vector->IncrementSubscripts (vectorAt);
920
+ } else {
921
+ resultElements.push_back (field->At (fieldAt));
922
+ }
923
+ mask->IncrementSubscripts (maskAt);
924
+ field->IncrementSubscripts (fieldAt);
925
+ }
926
+ return Expr<T>{
927
+ PackageConstant<T>(std::move (resultElements), *vector, mask->shape ())};
928
+ }
929
+
856
930
template <typename T>
857
931
Expr<T> FoldMINorMAX (
858
932
FoldingContext &context, FunctionRef<T> &&funcRef, Ordering order) {
@@ -943,8 +1017,12 @@ Expr<T> FoldOperation(FoldingContext &context, FunctionRef<T> &&funcRef) {
943
1017
return Folder<T>{context}.PACK (std::move (funcRef));
944
1018
} else if (name == " reshape" ) {
945
1019
return Folder<T>{context}.RESHAPE (std::move (funcRef));
1020
+ } else if (name == " transpose" ) {
1021
+ return Folder<T>{context}.TRANSPOSE (std::move (funcRef));
1022
+ } else if (name == " unpack" ) {
1023
+ return Folder<T>{context}.UNPACK (std::move (funcRef));
946
1024
}
947
- // TODO: spread, unpack, transpose
1025
+ // TODO: spread
948
1026
// TODO: extends_type_of, same_type_as
949
1027
if constexpr (!std::is_same_v<T, SomeDerived>) {
950
1028
return FoldIntrinsicFunction (context, std::move (funcRef));
0 commit comments