@@ -2,7 +2,7 @@ module test_np
2
2
use stdlib_array
3
3
use stdlib_filesystem, only : temp_dir
4
4
use stdlib_kinds, only : int8, int16, int32, int64, sp, dp
5
- use stdlib_io_np, only : save_npy, load_npy, load_npz
5
+ use stdlib_io_np, only : save_npy, load_npy, load_npz, save_npz
6
6
use testdrive, only : new_unittest, unittest_type, error_type, check, test_failed
7
7
implicit none
8
8
private
@@ -47,7 +47,9 @@ subroutine collect_np(testsuite)
47
47
new_unittest(" npz_load_arr_arange_10_20" , npz_load_arr_arange_10_20), &
48
48
new_unittest(" npz_load_arr_cmplx" , npz_load_arr_cmplx), &
49
49
new_unittest(" npz_load_two_arr_iint64_rdp" , npz_load_two_arr_iint64_rdp), &
50
- new_unittest(" npz_load_two_arr_iint64_rdp_comp" , npz_load_two_arr_iint64_rdp_comp) &
50
+ new_unittest(" npz_load_two_arr_iint64_rdp_comp" , npz_load_two_arr_iint64_rdp_comp), &
51
+ new_unittest(" npz_save_empty_array_input" , npz_save_empty_array_input, should_fail= .true. ), &
52
+ new_unittest(" npz_save_rdp_2" , npz_save_rdp_2) &
51
53
]
52
54
end subroutine collect_np
53
55
@@ -735,11 +737,11 @@ subroutine npz_load_arr_empty_0(error)
735
737
path = get_path(filename)
736
738
call load_npz(path, arrays, stat, tmp_dir= tmp)
737
739
call check(error, stat, " Loading an npz that contains a single empty array shouldn't fail." )
738
- if (stat /= 0 ) return
740
+ if (allocated (error) ) return
739
741
call check(error, size (arrays) == 1 , " '" // filename// " ' is supposed to contain a single array." )
740
- if (size (arrays) /= 1 ) return
742
+ if (allocated (error) ) return
741
743
call check(error, arrays(1 )% array% name == " arr_0.npy" , " Wrong array name." )
742
- if (arrays( 1 ) % array % name /= " arr_0.npy " ) return
744
+ if (allocated (error) ) return
743
745
select type (typed_array = > arrays(1 )% array)
744
746
class is (t_array_rdp_1)
745
747
call check(error, size (typed_array% values) == 0 , " Array in '" // filename// " ' is supposed to be empty." )
@@ -760,11 +762,11 @@ subroutine npz_load_arr_rand_2_3(error)
760
762
path = get_path(filename)
761
763
call load_npz(path, arrays, stat, tmp_dir= tmp)
762
764
call check(error, stat, " Loading an npz file that contains a valid nd_array shouldn't fail." )
763
- if (stat /= 0 ) return
765
+ if (allocated (error) ) return
764
766
call check(error, size (arrays) == 1 , " '" // filename// " ' is supposed to contain a single array." )
765
- if (size (arrays) /= 1 ) return
767
+ if (allocated (error) ) return
766
768
call check(error, arrays(1 )% array% name == " arr_0.npy" , " Wrong array name." )
767
- if (arrays( 1 ) % array % name /= " arr_0.npy " ) return
769
+ if (allocated (error) ) return
768
770
select type (typed_array = > arrays(1 )% array)
769
771
class is (t_array_rdp_2)
770
772
call check(error, size (typed_array% values) == 6 , " Array in '" // filename// " ' is supposed to have 6 entries." )
@@ -786,20 +788,20 @@ subroutine npz_load_arr_arange_10_20(error)
786
788
path = get_path(filename)
787
789
call load_npz(path, arrays, stat, tmp_dir= tmp)
788
790
call check(error, stat, " Loading an npz file that contains a valid nd_array shouldn't fail." )
789
- if (stat /= 0 ) return
791
+ if (allocated (error) ) return
790
792
call check(error, size (arrays) == 1 , " '" // filename// " ' is supposed to contain a single array." )
791
- if (size (arrays) /= 1 ) return
793
+ if (allocated (error) ) return
792
794
call check(error, arrays(1 )% array% name == " arr_0.npy" , " Wrong array name." )
793
- if (arrays( 1 ) % array % name /= " arr_0.npy " ) return
795
+ if (allocated (error) ) return
794
796
select type (typed_array = > arrays(1 )% array)
795
797
class is (t_array_iint64_1)
796
798
call check(error, size (typed_array% values) == 10 , " Array in '" // filename// " ' is supposed to have 10 entries." )
797
- if (size (typed_array % values) /= 10 ) return
799
+ if (allocated (error) ) return
798
800
call check(error, typed_array% values(1 ) == 10 , " First entry is supposed to be 10." )
799
- if (typed_array % values( 1 ) /= 10 ) return
801
+ if (allocated (error) ) return
800
802
do i = 2 , 10
801
803
call check(error, typed_array% values(i) == typed_array% values(i-1 ) + 1 , " Array is supposed to be an arange." )
802
- if (typed_array % values(i) /= typed_array % values(i -1 ) + 1 ) return
804
+ if (allocated (error) ) return
803
805
end do
804
806
class default
805
807
call test_failed(error, " Array in '" // filename// " ' is of wrong type." )
@@ -818,21 +820,21 @@ subroutine npz_load_arr_cmplx(error)
818
820
path = get_path(filename)
819
821
call load_npz(path, arrays, stat, tmp_dir= tmp)
820
822
call check(error, stat, " Loading an npz file that contains a valid nd_array shouldn't fail." )
821
- if (stat /= 0 ) return
823
+ if (allocated (error) ) return
822
824
call check(error, size (arrays) == 1 , " '" // filename// " ' is supposed to contain a single array." )
823
- if (size (arrays) /= 1 ) return
825
+ if (allocated (error) ) return
824
826
call check(error, arrays(1 )% array% name == " cmplx.npy" , " Wrong array name." )
825
- if (arrays( 1 ) % array % name /= " cmplx.npy " ) return
827
+ if (allocated (error) ) return
826
828
select type (typed_array = > arrays(1 )% array)
827
829
class is (t_array_csp_1)
828
830
call check(error, size (typed_array% values) == 3 , " Array in '" // filename// " ' is supposed to have 3 entries." )
829
- if (size (typed_array % values) /= 3 ) return
831
+ if (allocated (error) ) return
830
832
call check(error, typed_array% values(1 ) == cmplx (1_dp , 2_dp ), " First complex number does not match." )
831
- if (typed_array % values( 1 ) /= cmplx ( 1_dp , 2_dp )) return
833
+ if (allocated (error )) return
832
834
call check(error, typed_array% values(2 ) == cmplx (3_dp , 4_dp ), " Second complex number does not match." )
833
- if (typed_array % values( 2 ) /= cmplx ( 3_dp , 4_dp )) return
835
+ if (allocated (error )) return
834
836
call check(error, typed_array% values(3 ) == cmplx (5_dp , 6_dp ), " Third complex number does not match." )
835
- if (typed_array % values( 3 ) /= cmplx ( 5_dp , 6_dp )) return
837
+ if (allocated (error )) return
836
838
class default
837
839
call test_failed(error, " Array in '" // filename// " ' is of wrong type." )
838
840
end select
@@ -850,36 +852,36 @@ subroutine npz_load_two_arr_iint64_rdp(error)
850
852
path = get_path(filename)
851
853
call load_npz(path, arrays, stat, tmp_dir= tmp)
852
854
call check(error, stat, " Loading an npz file that contains valid nd_arrays shouldn't fail." )
853
- if (stat /= 0 ) return
855
+ if (allocated (error) ) return
854
856
call check(error, size (arrays) == 2 , " '" // filename// " ' is supposed to contain two arrays." )
855
- if (size (arrays) /= 2 ) return
857
+ if (allocated (error) ) return
856
858
call check(error, arrays(1 )% array% name == " arr_0.npy" , " Wrong array name." )
857
- if (arrays( 1 ) % array % name /= " arr_0.npy " ) return
859
+ if (allocated (error) ) return
858
860
call check(error, arrays(2 )% array% name == " arr_1.npy" , " Wrong array name." )
859
- if (arrays( 2 ) % array % name /= " arr_1.npy " ) return
861
+ if (allocated (error) ) return
860
862
select type (typed_array = > arrays(1 )% array)
861
863
class is (t_array_iint64_1)
862
864
call check(error, size (typed_array% values) == 3 , " Array in '" // filename// " ' is supposed to have 3 entries." )
863
- if (size (typed_array % values) /= 3 ) return
865
+ if (allocated (error) ) return
864
866
call check(error, typed_array% values(1 ) == 1 , " First integer does not match." )
865
- if (typed_array % values( 1 ) /= 1 ) return
867
+ if (allocated (error) ) return
866
868
call check(error, typed_array% values(2 ) == 2 , " Second integer does not match." )
867
- if (typed_array % values( 2 ) /= 2 ) return
869
+ if (allocated (error) ) return
868
870
call check(error, typed_array% values(3 ) == 3 , " Third integer does not match." )
869
- if (typed_array % values( 3 ) /= 3 ) return
871
+ if (allocated (error) ) return
870
872
class default
871
873
call test_failed(error, " Array in '" // filename// " ' is of wrong type." )
872
874
end select
873
875
select type (typed_array = > arrays(2 )% array)
874
876
class is (t_array_rdp_1)
875
877
call check(error, size (typed_array% values) == 3 , " Array in '" // filename// " ' is supposed to have 3 entries." )
876
- if (size (typed_array % values) /= 3 ) return
878
+ if (allocated (error) ) return
877
879
call check(error, typed_array% values(1 ) == 1 ., " First number does not match." )
878
- if (typed_array % values( 1 ) /= 1 . ) return
880
+ if (allocated (error) ) return
879
881
call check(error, typed_array% values(2 ) == 1 ., " Second number does not match." )
880
- if (typed_array % values( 2 ) /= 1 . ) return
882
+ if (allocated (error) ) return
881
883
call check(error, typed_array% values(3 ) == 1 ., " Third number does not match." )
882
- if (typed_array % values( 3 ) /= 1 . ) return
884
+ if (allocated (error) ) return
883
885
class default
884
886
call test_failed(error, " Array in '" // filename// " ' is of wrong type." )
885
887
end select
@@ -889,49 +891,91 @@ subroutine npz_load_two_arr_iint64_rdp_comp(error)
889
891
type (error_type), allocatable , intent (out ) :: error
890
892
891
893
type (t_array_wrapper), allocatable :: arrays(:)
892
- integer :: stat, i
894
+ integer :: stat
893
895
character (* ), parameter :: filename = " two_arr_iint64_rdp_comp.npz"
894
896
character (* ), parameter :: tmp = temp_dir// " two_arr_iint64_rdp_comp"
895
897
character (:), allocatable :: path
896
898
897
899
path = get_path(filename)
898
900
call load_npz(path, arrays, stat, tmp_dir= tmp)
899
901
call check(error, stat, " Loading a compressed npz file that contains valid nd_arrays shouldn't fail." )
900
- if (stat /= 0 ) return
902
+ if (allocated (error) ) return
901
903
call check(error, size (arrays) == 2 , " '" // filename// " ' is supposed to contain two arrays." )
902
- if (size (arrays) /= 2 ) return
904
+ if (allocated (error) ) return
903
905
call check(error, arrays(1 )% array% name == " arr_0.npy" , " Wrong array name." )
904
- if (arrays( 1 ) % array % name /= " arr_0.npy " ) return
906
+ if (allocated (error) ) return
905
907
call check(error, arrays(2 )% array% name == " arr_1.npy" , " Wrong array name." )
906
- if (arrays( 2 ) % array % name /= " arr_1.npy " ) return
908
+ if (allocated (error) ) return
907
909
select type (typed_array = > arrays(1 )% array)
908
910
class is (t_array_iint64_1)
909
911
call check(error, size (typed_array% values) == 3 , " Array in '" // filename// " ' is supposed to have 3 entries." )
910
- if (size (typed_array % values) /= 3 ) return
912
+ if (allocated (error) ) return
911
913
call check(error, typed_array% values(1 ) == 1 , " First integer does not match." )
912
- if (typed_array % values( 1 ) /= 1 ) return
914
+ if (allocated (error) ) return
913
915
call check(error, typed_array% values(2 ) == 2 , " Second integer does not match." )
914
- if (typed_array % values( 2 ) /= 2 ) return
916
+ if (allocated (error) ) return
915
917
call check(error, typed_array% values(3 ) == 3 , " Third integer does not match." )
916
- if (typed_array % values( 3 ) /= 3 ) return
918
+ if (allocated (error) ) return
917
919
class default
918
920
call test_failed(error, " Array in '" // filename// " ' is of wrong type." )
919
921
end select
920
922
select type (typed_array = > arrays(2 )% array)
921
923
class is (t_array_rdp_1)
922
924
call check(error, size (typed_array% values) == 3 , " Array in '" // filename// " ' is supposed to have 3 entries." )
923
- if (size (typed_array % values) /= 3 ) return
925
+ if (allocated (error) ) return
924
926
call check(error, typed_array% values(1 ) == 1 ., " First number does not match." )
925
- if (typed_array % values( 1 ) /= 1 . ) return
927
+ if (allocated (error) ) return
926
928
call check(error, typed_array% values(2 ) == 1 ., " Second number does not match." )
927
- if (typed_array % values( 2 ) /= 1 . ) return
929
+ if (allocated (error) ) return
928
930
call check(error, typed_array% values(3 ) == 1 ., " Third number does not match." )
929
- if (typed_array % values( 3 ) /= 1 . ) return
931
+ if (allocated (error) ) return
930
932
class default
931
933
call test_failed(error, " Array in '" // filename// " ' is of wrong type." )
932
934
end select
933
935
end
934
936
937
+ subroutine npz_save_empty_array_input (error )
938
+ type (error_type), allocatable , intent (out ) :: error
939
+
940
+ type (t_array_wrapper), allocatable :: arrays(:)
941
+ integer :: stat
942
+ character (* ), parameter :: filename = " output.npz"
943
+
944
+ allocate (arrays(0 ))
945
+ call save_npz(filename, arrays, stat)
946
+ call check(error, stat, " Trying to save an empty array fail." )
947
+ end
948
+
949
+ subroutine npz_save_rdp_2 (error )
950
+ type (error_type), allocatable , intent (out ) :: error
951
+
952
+ type (t_array_wrapper), allocatable :: arrays(:)
953
+ integer :: stat
954
+ character (* ), parameter :: filename = " npz_save_rdp_2.npz"
955
+ character (* ), parameter :: arr_name = " arr_0.npy"
956
+ real (dp), allocatable :: input(:,:), output(:,:)
957
+
958
+ allocate (input(10 , 4 ))
959
+ call random_number (input)
960
+ ! call add_array(arrays, input)
961
+
962
+ ! call save_npz(filename, arrays, stat)
963
+ ! call check(error, stat, "Writing of npz file failed")
964
+ ! if (allocated(error)) return
965
+
966
+ ! call load_npy(filename, output, stat)
967
+ ! call delete_file(filename)
968
+
969
+ ! call check(error, stat, "Reading of npy file failed")
970
+ ! if (allocated(error)) return
971
+
972
+ ! call check(error, size(output), size(input))
973
+ ! if (allocated(error)) return
974
+
975
+ ! call check(error, any(abs(output - input) <= epsilon(1.0_dp)), &
976
+ ! "Precision loss when rereading array")
977
+ end
978
+
935
979
! > Makes sure that we find the file when running both `ctest` and `fpm test`.
936
980
function get_path (file ) result(path)
937
981
character (* ), intent (in ) :: file
0 commit comments