Skip to content

Commit d2a5ad7

Browse files
committed
Improve tests by checking error, add a test for npz_write
1 parent 42dc45c commit d2a5ad7

File tree

5 files changed

+112
-72
lines changed

5 files changed

+112
-72
lines changed

src/stdlib_io_np.fypp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ module stdlib_io_np
7474
implicit none
7575
private
7676

77-
public :: load_npy, save_npy, load_npz, save_npz
77+
public :: load_npy, save_npy, load_npz, save_npz, add_array
7878

7979
character(len=*), parameter :: &
8080
type_iint8 = "<i1", type_iint16 = "<i2", type_iint32 = "<i4", type_iint64 = "<i8", &
@@ -135,13 +135,13 @@ module stdlib_io_np
135135
!> Save multidimensional arrays to a compressed or an uncompressed npz file.
136136
!> ([Specification](../page/specs/stdlib_io.html#save_npz))
137137
interface save_npz
138-
module subroutine save_npz_from_arrays(filename, arrays, compressed, iostat, iomsg)
138+
module subroutine save_npz_from_arrays(filename, arrays, iostat, iomsg, compressed)
139139
character(len=*), intent(in) :: filename
140140
type(t_array_wrapper), intent(in) :: arrays(:)
141-
!> If true, the file is saved in compressed format. The default is false.
142-
logical, intent(in), optional :: compressed
143141
integer, intent(out), optional :: iostat
144142
character(len=:), allocatable, intent(out), optional :: iomsg
143+
!> If true, the file is saved in compressed format. The default is false.
144+
logical, intent(in), optional :: compressed
145145
end
146146
end interface
147147

@@ -159,4 +159,9 @@ module stdlib_io_np
159159
#:endfor
160160
#:endfor
161161
end interface
162+
163+
contains
164+
165+
subroutine add_array()
166+
end
162167
end

src/stdlib_io_np_save.fypp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -139,13 +139,13 @@ contains
139139
!>
140140
!> Save multidimensional arrays to a compressed or an uncompressed npz file.
141141
!> ([Specification](../page/specs/stdlib_io.html#save_npz))
142-
module subroutine save_npz_from_arrays(filename, arrays, compressed, iostat, iomsg)
142+
module subroutine save_npz_from_arrays(filename, arrays, iostat, iomsg, compressed)
143143
character(len=*), intent(in) :: filename
144144
type(t_array_wrapper), intent(in) :: arrays(:)
145-
!> If true, the file is saved in compressed format. The default is false.
146-
logical, intent(in), optional :: compressed
147145
integer, intent(out), optional :: iostat
148146
character(len=:), allocatable, intent(out), optional :: iomsg
147+
!> If true, the file is saved in compressed format. The default is false.
148+
logical, intent(in), optional :: compressed
149149

150150
integer :: i, j, stat
151151
logical :: is_compressed

src/stdlib_io_zip.f90

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ subroutine zip(output_file, files, stat, msg)
2626

2727
if (trim(output_file) == '') then
2828
if (present(stat)) stat = 1
29-
if (present(msg)) msg = "Output file name is empty."
29+
if (present(msg)) msg = "Output file is empty."
3030
return
3131
end if
3232

test/io/test_filesystem.f90

Lines changed: 9 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -97,14 +97,12 @@ subroutine fs_list_dir_empty(error)
9797

9898
call run('rm -rf '//temp_list_dir, stat=stat)
9999
if (stat /= 0) then
100-
call test_failed(error, "Removing directory '"//temp_list_dir//"' failed.")
101-
return
100+
call test_failed(error, "Removing directory '"//temp_list_dir//"' failed."); return
102101
end if
103102

104103
call run('mkdir '//temp_list_dir, stat=stat)
105104
if (stat /= 0) then
106-
call test_failed(error, "Creating directory '"//temp_list_dir//"' failed.")
107-
return
105+
call test_failed(error, "Creating directory '"//temp_list_dir//"' failed."); return
108106
end if
109107

110108
call list_dir(temp_list_dir, files, stat)
@@ -124,20 +122,17 @@ subroutine fs_list_dir_one_file(error)
124122

125123
call run('rm -rf '//temp_list_dir, stat=stat)
126124
if (stat /= 0) then
127-
call test_failed(error, "Removing directory '"//temp_list_dir//"' failed.")
128-
return
125+
call test_failed(error, "Removing directory '"//temp_list_dir//"' failed."); return
129126
end if
130127

131128
call run('mkdir '//temp_list_dir, stat=stat)
132129
if (stat /= 0) then
133-
call test_failed(error, "Creating directory '"//temp_list_dir//"' failed.")
134-
return
130+
call test_failed(error, "Creating directory '"//temp_list_dir//"' failed."); return
135131
end if
136132

137133
call run('touch '//temp_list_dir//'/'//filename, stat=stat)
138134
if (stat /= 0) then
139-
call test_failed(error, "Creating file'"//filename//"' in directory '"//temp_list_dir//"' failed.")
140-
return
135+
call test_failed(error, "Creating file'"//filename//"' in directory '"//temp_list_dir//"' failed."); return
141136
end if
142137

143138
call list_dir(temp_list_dir, files, stat)
@@ -159,26 +154,22 @@ subroutine fs_list_dir_two_files(error)
159154

160155
call run('rm -rf '//temp_list_dir, stat=stat)
161156
if (stat /= 0) then
162-
call test_failed(error, "Removing directory '"//temp_list_dir//"' failed.")
163-
return
157+
call test_failed(error, "Removing directory '"//temp_list_dir//"' failed."); return
164158
end if
165159

166160
call run('mkdir '//temp_list_dir, stat=stat)
167161
if (stat /= 0) then
168-
call test_failed(error, "Creating directory '"//temp_list_dir//"' failed.")
169-
return
162+
call test_failed(error, "Creating directory '"//temp_list_dir//"' failed."); return
170163
end if
171164

172165
call run('touch '//temp_list_dir//'/'//filename1, stat=stat)
173166
if (stat /= 0) then
174-
call test_failed(error, "Creating file 1 in directory '"//temp_list_dir//"' failed.")
175-
return
167+
call test_failed(error, "Creating file 1 in directory '"//temp_list_dir//"' failed."); return
176168
end if
177169

178170
call run('touch '//temp_list_dir//'/'//filename2, stat=stat)
179171
if (stat /= 0) then
180-
call test_failed(error, "Creating file 2 in directory '"//temp_list_dir//"' failed.")
181-
return
172+
call test_failed(error, "Creating file 2 in directory '"//temp_list_dir//"' failed."); return
182173
end if
183174

184175
call list_dir(temp_list_dir, files, stat)

test/io/test_np.f90

Lines changed: 90 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module test_np
22
use stdlib_array
33
use stdlib_filesystem, only : temp_dir
44
use stdlib_kinds, only : int8, int16, int32, int64, sp, dp
5-
use stdlib_io_np, only : save_npy, load_npy, load_npz
5+
use stdlib_io_np, only : save_npy, load_npy, load_npz, save_npz
66
use testdrive, only : new_unittest, unittest_type, error_type, check, test_failed
77
implicit none
88
private
@@ -47,7 +47,9 @@ subroutine collect_np(testsuite)
4747
new_unittest("npz_load_arr_arange_10_20", npz_load_arr_arange_10_20), &
4848
new_unittest("npz_load_arr_cmplx", npz_load_arr_cmplx), &
4949
new_unittest("npz_load_two_arr_iint64_rdp", npz_load_two_arr_iint64_rdp), &
50-
new_unittest("npz_load_two_arr_iint64_rdp_comp", npz_load_two_arr_iint64_rdp_comp) &
50+
new_unittest("npz_load_two_arr_iint64_rdp_comp", npz_load_two_arr_iint64_rdp_comp), &
51+
new_unittest("npz_save_empty_array_input", npz_save_empty_array_input, should_fail=.true.), &
52+
new_unittest("npz_save_rdp_2", npz_save_rdp_2) &
5153
]
5254
end subroutine collect_np
5355

@@ -735,11 +737,11 @@ subroutine npz_load_arr_empty_0(error)
735737
path = get_path(filename)
736738
call load_npz(path, arrays, stat, tmp_dir=tmp)
737739
call check(error, stat, "Loading an npz that contains a single empty array shouldn't fail.")
738-
if (stat /= 0) return
740+
if (allocated(error)) return
739741
call check(error, size(arrays) == 1, "'"//filename//"' is supposed to contain a single array.")
740-
if (size(arrays) /= 1) return
742+
if (allocated(error)) return
741743
call check(error, arrays(1)%array%name == "arr_0.npy", "Wrong array name.")
742-
if (arrays(1)%array%name /= "arr_0.npy") return
744+
if (allocated(error)) return
743745
select type (typed_array => arrays(1)%array)
744746
class is (t_array_rdp_1)
745747
call check(error, size(typed_array%values) == 0, "Array in '"//filename//"' is supposed to be empty.")
@@ -760,11 +762,11 @@ subroutine npz_load_arr_rand_2_3(error)
760762
path = get_path(filename)
761763
call load_npz(path, arrays, stat, tmp_dir=tmp)
762764
call check(error, stat, "Loading an npz file that contains a valid nd_array shouldn't fail.")
763-
if (stat /= 0) return
765+
if (allocated(error)) return
764766
call check(error, size(arrays) == 1, "'"//filename//"' is supposed to contain a single array.")
765-
if (size(arrays) /= 1) return
767+
if (allocated(error)) return
766768
call check(error, arrays(1)%array%name == "arr_0.npy", "Wrong array name.")
767-
if (arrays(1)%array%name /= "arr_0.npy") return
769+
if (allocated(error)) return
768770
select type (typed_array => arrays(1)%array)
769771
class is (t_array_rdp_2)
770772
call check(error, size(typed_array%values) == 6, "Array in '"//filename//"' is supposed to have 6 entries.")
@@ -786,20 +788,20 @@ subroutine npz_load_arr_arange_10_20(error)
786788
path = get_path(filename)
787789
call load_npz(path, arrays, stat, tmp_dir=tmp)
788790
call check(error, stat, "Loading an npz file that contains a valid nd_array shouldn't fail.")
789-
if (stat /= 0) return
791+
if (allocated(error)) return
790792
call check(error, size(arrays) == 1, "'"//filename//"' is supposed to contain a single array.")
791-
if (size(arrays) /= 1) return
793+
if (allocated(error)) return
792794
call check(error, arrays(1)%array%name == "arr_0.npy", "Wrong array name.")
793-
if (arrays(1)%array%name /= "arr_0.npy") return
795+
if (allocated(error)) return
794796
select type (typed_array => arrays(1)%array)
795797
class is (t_array_iint64_1)
796798
call check(error, size(typed_array%values) == 10, "Array in '"//filename//"' is supposed to have 10 entries.")
797-
if (size(typed_array%values) /= 10) return
799+
if (allocated(error)) return
798800
call check(error, typed_array%values(1) == 10, "First entry is supposed to be 10.")
799-
if (typed_array%values(1) /= 10) return
801+
if (allocated(error)) return
800802
do i = 2, 10
801803
call check(error, typed_array%values(i) == typed_array%values(i-1) + 1, "Array is supposed to be an arange.")
802-
if (typed_array%values(i) /= typed_array%values(i-1) + 1) return
804+
if (allocated(error)) return
803805
end do
804806
class default
805807
call test_failed(error, "Array in '"//filename//"' is of wrong type.")
@@ -818,21 +820,21 @@ subroutine npz_load_arr_cmplx(error)
818820
path = get_path(filename)
819821
call load_npz(path, arrays, stat, tmp_dir=tmp)
820822
call check(error, stat, "Loading an npz file that contains a valid nd_array shouldn't fail.")
821-
if (stat /= 0) return
823+
if (allocated(error)) return
822824
call check(error, size(arrays) == 1, "'"//filename//"' is supposed to contain a single array.")
823-
if (size(arrays) /= 1) return
825+
if (allocated(error)) return
824826
call check(error, arrays(1)%array%name == "cmplx.npy", "Wrong array name.")
825-
if (arrays(1)%array%name /= "cmplx.npy") return
827+
if (allocated(error)) return
826828
select type (typed_array => arrays(1)%array)
827829
class is (t_array_csp_1)
828830
call check(error, size(typed_array%values) == 3, "Array in '"//filename//"' is supposed to have 3 entries.")
829-
if (size(typed_array%values) /= 3) return
831+
if (allocated(error)) return
830832
call check(error, typed_array%values(1) == cmplx(1_dp, 2_dp), "First complex number does not match.")
831-
if (typed_array%values(1) /= cmplx(1_dp, 2_dp)) return
833+
if (allocated(error)) return
832834
call check(error, typed_array%values(2) == cmplx(3_dp, 4_dp), "Second complex number does not match.")
833-
if (typed_array%values(2) /= cmplx(3_dp, 4_dp)) return
835+
if (allocated(error)) return
834836
call check(error, typed_array%values(3) == cmplx(5_dp, 6_dp), "Third complex number does not match.")
835-
if (typed_array%values(3) /= cmplx(5_dp, 6_dp)) return
837+
if (allocated(error)) return
836838
class default
837839
call test_failed(error, "Array in '"//filename//"' is of wrong type.")
838840
end select
@@ -850,36 +852,36 @@ subroutine npz_load_two_arr_iint64_rdp(error)
850852
path = get_path(filename)
851853
call load_npz(path, arrays, stat, tmp_dir=tmp)
852854
call check(error, stat, "Loading an npz file that contains valid nd_arrays shouldn't fail.")
853-
if (stat /= 0) return
855+
if (allocated(error)) return
854856
call check(error, size(arrays) == 2, "'"//filename//"' is supposed to contain two arrays.")
855-
if (size(arrays) /= 2) return
857+
if (allocated(error)) return
856858
call check(error, arrays(1)%array%name == "arr_0.npy", "Wrong array name.")
857-
if (arrays(1)%array%name /= "arr_0.npy") return
859+
if (allocated(error)) return
858860
call check(error, arrays(2)%array%name == "arr_1.npy", "Wrong array name.")
859-
if (arrays(2)%array%name /= "arr_1.npy") return
861+
if (allocated(error)) return
860862
select type (typed_array => arrays(1)%array)
861863
class is (t_array_iint64_1)
862864
call check(error, size(typed_array%values) == 3, "Array in '"//filename//"' is supposed to have 3 entries.")
863-
if (size(typed_array%values) /= 3) return
865+
if (allocated(error)) return
864866
call check(error, typed_array%values(1) == 1, "First integer does not match.")
865-
if (typed_array%values(1) /= 1) return
867+
if (allocated(error)) return
866868
call check(error, typed_array%values(2) == 2, "Second integer does not match.")
867-
if (typed_array%values(2) /= 2) return
869+
if (allocated(error)) return
868870
call check(error, typed_array%values(3) == 3, "Third integer does not match.")
869-
if (typed_array%values(3) /= 3) return
871+
if (allocated(error)) return
870872
class default
871873
call test_failed(error, "Array in '"//filename//"' is of wrong type.")
872874
end select
873875
select type (typed_array => arrays(2)%array)
874876
class is (t_array_rdp_1)
875877
call check(error, size(typed_array%values) == 3, "Array in '"//filename//"' is supposed to have 3 entries.")
876-
if (size(typed_array%values) /= 3) return
878+
if (allocated(error)) return
877879
call check(error, typed_array%values(1) == 1., "First number does not match.")
878-
if (typed_array%values(1) /= 1.) return
880+
if (allocated(error)) return
879881
call check(error, typed_array%values(2) == 1., "Second number does not match.")
880-
if (typed_array%values(2) /= 1.) return
882+
if (allocated(error)) return
881883
call check(error, typed_array%values(3) == 1., "Third number does not match.")
882-
if (typed_array%values(3) /= 1.) return
884+
if (allocated(error)) return
883885
class default
884886
call test_failed(error, "Array in '"//filename//"' is of wrong type.")
885887
end select
@@ -889,49 +891,91 @@ subroutine npz_load_two_arr_iint64_rdp_comp(error)
889891
type(error_type), allocatable, intent(out) :: error
890892

891893
type(t_array_wrapper), allocatable :: arrays(:)
892-
integer :: stat, i
894+
integer :: stat
893895
character(*), parameter :: filename = "two_arr_iint64_rdp_comp.npz"
894896
character(*), parameter :: tmp = temp_dir//"two_arr_iint64_rdp_comp"
895897
character(:), allocatable :: path
896898

897899
path = get_path(filename)
898900
call load_npz(path, arrays, stat, tmp_dir=tmp)
899901
call check(error, stat, "Loading a compressed npz file that contains valid nd_arrays shouldn't fail.")
900-
if (stat /= 0) return
902+
if (allocated(error)) return
901903
call check(error, size(arrays) == 2, "'"//filename//"' is supposed to contain two arrays.")
902-
if (size(arrays) /= 2) return
904+
if (allocated(error)) return
903905
call check(error, arrays(1)%array%name == "arr_0.npy", "Wrong array name.")
904-
if (arrays(1)%array%name /= "arr_0.npy") return
906+
if (allocated(error)) return
905907
call check(error, arrays(2)%array%name == "arr_1.npy", "Wrong array name.")
906-
if (arrays(2)%array%name /= "arr_1.npy") return
908+
if (allocated(error)) return
907909
select type (typed_array => arrays(1)%array)
908910
class is (t_array_iint64_1)
909911
call check(error, size(typed_array%values) == 3, "Array in '"//filename//"' is supposed to have 3 entries.")
910-
if (size(typed_array%values) /= 3) return
912+
if (allocated(error)) return
911913
call check(error, typed_array%values(1) == 1, "First integer does not match.")
912-
if (typed_array%values(1) /= 1) return
914+
if (allocated(error)) return
913915
call check(error, typed_array%values(2) == 2, "Second integer does not match.")
914-
if (typed_array%values(2) /= 2) return
916+
if (allocated(error)) return
915917
call check(error, typed_array%values(3) == 3, "Third integer does not match.")
916-
if (typed_array%values(3) /= 3) return
918+
if (allocated(error)) return
917919
class default
918920
call test_failed(error, "Array in '"//filename//"' is of wrong type.")
919921
end select
920922
select type (typed_array => arrays(2)%array)
921923
class is (t_array_rdp_1)
922924
call check(error, size(typed_array%values) == 3, "Array in '"//filename//"' is supposed to have 3 entries.")
923-
if (size(typed_array%values) /= 3) return
925+
if (allocated(error)) return
924926
call check(error, typed_array%values(1) == 1., "First number does not match.")
925-
if (typed_array%values(1) /= 1.) return
927+
if (allocated(error)) return
926928
call check(error, typed_array%values(2) == 1., "Second number does not match.")
927-
if (typed_array%values(2) /= 1.) return
929+
if (allocated(error)) return
928930
call check(error, typed_array%values(3) == 1., "Third number does not match.")
929-
if (typed_array%values(3) /= 1.) return
931+
if (allocated(error)) return
930932
class default
931933
call test_failed(error, "Array in '"//filename//"' is of wrong type.")
932934
end select
933935
end
934936

937+
subroutine npz_save_empty_array_input(error)
938+
type(error_type), allocatable, intent(out) :: error
939+
940+
type(t_array_wrapper), allocatable :: arrays(:)
941+
integer :: stat
942+
character(*), parameter :: filename = "output.npz"
943+
944+
allocate(arrays(0))
945+
call save_npz(filename, arrays, stat)
946+
call check(error, stat, "Trying to save an empty array fail.")
947+
end
948+
949+
subroutine npz_save_rdp_2(error)
950+
type(error_type), allocatable, intent(out) :: error
951+
952+
type(t_array_wrapper), allocatable :: arrays(:)
953+
integer :: stat
954+
character(*), parameter :: filename = "npz_save_rdp_2.npz"
955+
character(*), parameter :: arr_name = "arr_0.npy"
956+
real(dp), allocatable :: input(:,:), output(:,:)
957+
958+
allocate(input(10, 4))
959+
call random_number(input)
960+
! call add_array(arrays, input)
961+
962+
! call save_npz(filename, arrays, stat)
963+
! call check(error, stat, "Writing of npz file failed")
964+
! if (allocated(error)) return
965+
966+
! call load_npy(filename, output, stat)
967+
! call delete_file(filename)
968+
969+
! call check(error, stat, "Reading of npy file failed")
970+
! if (allocated(error)) return
971+
972+
! call check(error, size(output), size(input))
973+
! if (allocated(error)) return
974+
975+
! call check(error, any(abs(output - input) <= epsilon(1.0_dp)), &
976+
! "Precision loss when rereading array")
977+
end
978+
935979
!> Makes sure that we find the file when running both `ctest` and `fpm test`.
936980
function get_path(file) result(path)
937981
character(*), intent(in) :: file

0 commit comments

Comments
 (0)