1
1
module test_np
2
2
use stdlib_array
3
- use stdlib_filesystem, only : temp_dir
3
+ use stdlib_filesystem, only : temp_dir, exists
4
4
use stdlib_kinds, only : int8, int16, int32, int64, sp, dp
5
5
use stdlib_io_np, only : save_npy, load_npy, load_npz, add_array, save_npz
6
6
use testdrive, only : new_unittest, unittest_type, error_type, check, test_failed
@@ -48,12 +48,14 @@ subroutine collect_np(testsuite)
48
48
new_unittest(" npz_load_arr_cmplx" , npz_load_arr_cmplx), &
49
49
new_unittest(" npz_load_two_arr_iint64_rdp" , npz_load_two_arr_iint64_rdp), &
50
50
new_unittest(" npz_load_two_arr_iint64_rdp_comp" , npz_load_two_arr_iint64_rdp_comp), &
51
- new_unittest(" npz_add_arr_to_empty" , npz_add_arr_to_empty), &
52
- new_unittest(" npz_add_two_arrays" , npz_add_two_arrays), &
53
- new_unittest(" npz_add_arr_custom_name" , npz_add_arr_custom_name), &
54
- new_unittest(" npz_add_arr_empty_name" , npz_add_arr_empty_name, should_fail= .true. ), &
55
- new_unittest(" npz_add_arr_duplicate_names" , npz_add_arr_duplicate_names, should_fail= .true. ), &
56
- new_unittest(" npz_save_empty_array_input" , npz_save_empty_array_input, should_fail= .true. ) &
51
+ new_unittest(" add_array_to_empty" , add_array_to_empty), &
52
+ new_unittest(" add_two_arrays" , add_two_arrays), &
53
+ new_unittest(" add_array_custom_name" , add_array_custom_name), &
54
+ new_unittest(" add_array_empty_name" , add_array_empty_name, should_fail= .true. ), &
55
+ new_unittest(" add_array_duplicate_names" , add_array_duplicate_names, should_fail= .true. ), &
56
+ new_unittest(" npz_save_empty_array_input" , npz_save_empty_array_input, should_fail= .true. ), &
57
+ new_unittest(" npz_save_one_array" , npz_save_one_array), &
58
+ new_unittest(" npz_save_two_arrays" , npz_save_two_arrays) &
57
59
]
58
60
end subroutine collect_np
59
61
@@ -938,7 +940,7 @@ subroutine npz_load_two_arr_iint64_rdp_comp(error)
938
940
end select
939
941
end
940
942
941
- subroutine npz_add_arr_to_empty (error )
943
+ subroutine add_array_to_empty (error )
942
944
type (error_type), allocatable , intent (out ) :: error
943
945
944
946
type (t_array_wrapper), allocatable :: arrays(:)
@@ -966,7 +968,7 @@ subroutine npz_add_arr_to_empty(error)
966
968
end select
967
969
end
968
970
969
- subroutine npz_add_two_arrays (error )
971
+ subroutine add_two_arrays (error )
970
972
type (error_type), allocatable , intent (out ) :: error
971
973
972
974
type (t_array_wrapper), allocatable :: arrays(:)
@@ -1015,7 +1017,7 @@ subroutine npz_add_two_arrays(error)
1015
1017
end select
1016
1018
end
1017
1019
1018
- subroutine npz_add_arr_custom_name (error )
1020
+ subroutine add_array_custom_name (error )
1019
1021
type (error_type), allocatable , intent (out ) :: error
1020
1022
1021
1023
type (t_array_wrapper), allocatable :: arrays(:)
@@ -1044,7 +1046,7 @@ subroutine npz_add_arr_custom_name(error)
1044
1046
end select
1045
1047
end
1046
1048
1047
- subroutine npz_add_arr_empty_name (error )
1049
+ subroutine add_array_empty_name (error )
1048
1050
type (error_type), allocatable , intent (out ) :: error
1049
1051
1050
1052
type (t_array_wrapper), allocatable :: arrays(:)
@@ -1058,7 +1060,7 @@ subroutine npz_add_arr_empty_name(error)
1058
1060
call check(error, stat, " Empty file names are not allowed." )
1059
1061
end
1060
1062
1061
- subroutine npz_add_arr_duplicate_names (error )
1063
+ subroutine add_array_duplicate_names (error )
1062
1064
type (error_type), allocatable , intent (out ) :: error
1063
1065
1064
1066
type (t_array_wrapper), allocatable :: arrays(:)
@@ -1090,6 +1092,129 @@ subroutine npz_save_empty_array_input(error)
1090
1092
call check(error, stat, " Trying to save an empty array fail." )
1091
1093
end
1092
1094
1095
+ subroutine npz_save_one_array (error )
1096
+ type (error_type), allocatable , intent (out ) :: error
1097
+
1098
+ type (t_array_wrapper), allocatable :: arrays(:), arrays_reloaded(:)
1099
+ integer :: stat
1100
+ real (dp), allocatable :: input_array(:,:)
1101
+ character (* ), parameter :: output_file = " one_array.npz"
1102
+
1103
+ allocate (input_array(10 , 4 ))
1104
+ call random_number (input_array)
1105
+ call add_array(arrays, input_array, stat)
1106
+ call check(error, stat, " Error adding an array to the list of arrays." )
1107
+ if (allocated (error)) return
1108
+ call check(error, size (arrays) == 1 , " Array was not added to the list of arrays." )
1109
+ if (allocated (error)) return
1110
+ call save_npz(output_file, arrays, stat)
1111
+ call check(error, stat, " Error saving the array." )
1112
+ if (allocated (error)) then
1113
+ call delete_file(output_file); return
1114
+ end if
1115
+ call check(error, exists(output_file), " Output file does not exist." )
1116
+ if (allocated (error)) then
1117
+ call delete_file(output_file); return
1118
+ end if
1119
+
1120
+ call load_npz(output_file, arrays_reloaded, stat)
1121
+ call check(error, stat, " Error loading the npz file." )
1122
+ if (allocated (error)) then
1123
+ call delete_file(output_file); return
1124
+ end if
1125
+ call check(error, size (arrays_reloaded) == 1 , " Wrong number of arrays." )
1126
+ if (allocated (error)) then
1127
+ call delete_file(output_file); return
1128
+ end if
1129
+ select type (typed_array = > arrays_reloaded(1 )% array)
1130
+ class is (t_array_rdp_2)
1131
+ call check(error, size (typed_array% values), size (input_array), " Array sizes to not match." )
1132
+ if (allocated (error)) then
1133
+ call delete_file(output_file); return
1134
+ end if
1135
+ call check(error, any (abs (typed_array% values - input_array) <= epsilon (1.0_dp )), &
1136
+ " Precision loss when adding array." )
1137
+ if (allocated (error)) then
1138
+ call delete_file(output_file); return
1139
+ end if
1140
+ class default
1141
+ call test_failed(error, " Array is of wrong type." )
1142
+ end select
1143
+ call delete_file(output_file)
1144
+ end
1145
+
1146
+ subroutine npz_save_two_arrays (error )
1147
+ type (error_type), allocatable , intent (out ) :: error
1148
+
1149
+ type (t_array_wrapper), allocatable :: arrays(:), arrays_reloaded(:)
1150
+ integer :: stat
1151
+ real (dp), allocatable :: input_array_1(:,:)
1152
+ complex (dp), allocatable :: input_array_2(:)
1153
+ character (* ), parameter :: output_file = " two_arrays.npz"
1154
+
1155
+ allocate (input_array_1(5 , 6 ))
1156
+ call random_number (input_array_1)
1157
+ input_array_2 = [(1.0_dp , 2.0_dp ), (3.0_dp , 4.0_dp ), (5.0_dp , 6.0_dp )]
1158
+ call add_array(arrays, input_array_1, stat)
1159
+ call check(error, stat, " Error adding array 1 to the list of arrays." )
1160
+ if (allocated (error)) return
1161
+ call add_array(arrays, input_array_2, stat)
1162
+ call check(error, stat, " Error adding array 2 to the list of arrays." )
1163
+ if (allocated (error)) return
1164
+ call check(error, size (arrays) == 2 , " Wrong array size." )
1165
+ if (allocated (error)) return
1166
+ call save_npz(output_file, arrays, stat)
1167
+ call check(error, stat, " Error saving arrays as an npz file." )
1168
+ if (allocated (error)) then
1169
+ call delete_file(output_file); return
1170
+ end if
1171
+ call check(error, exists(output_file), " Output file does not exist." )
1172
+ if (allocated (error)) then
1173
+ call delete_file(output_file); return
1174
+ end if
1175
+
1176
+ call load_npz(output_file, arrays_reloaded, stat)
1177
+ call check(error, stat, " Error loading npz file." )
1178
+ if (allocated (error)) then
1179
+ call delete_file(output_file); return
1180
+ end if
1181
+ call check(error, size (arrays_reloaded) == 2 , " Wrong number of arrays." )
1182
+ if (allocated (error)) then
1183
+ call delete_file(output_file); return
1184
+ end if
1185
+
1186
+ select type (typed_array = > arrays_reloaded(1 )% array)
1187
+ class is (t_array_rdp_2)
1188
+ call check(error, size (typed_array% values), size (input_array_1), " Array sizes to not match." )
1189
+ if (allocated (error)) then
1190
+ call delete_file(output_file); return
1191
+ end if
1192
+ call check(error, any (abs (typed_array% values - input_array_1) <= epsilon (1.0_dp )), &
1193
+ " Precision loss when adding array." )
1194
+ if (allocated (error)) then
1195
+ call delete_file(output_file); return
1196
+ end if
1197
+ class default
1198
+ call test_failed(error, " Array 1 is of wrong type." )
1199
+ end select
1200
+
1201
+ select type (typed_array = > arrays_reloaded(2 )% array)
1202
+ class is (t_array_cdp_2)
1203
+ call check(error, size (typed_array% values), size (input_array_2), " Array sizes to not match." )
1204
+ if (allocated (error)) then
1205
+ call delete_file(output_file); return
1206
+ end if
1207
+ call check(error, any (abs (typed_array% values - input_array_2) <= epsilon (1.0_dp )), &
1208
+ " Precision loss when adding array." )
1209
+ if (allocated (error)) then
1210
+ call delete_file(output_file); return
1211
+ end if
1212
+ class default
1213
+ call test_failed(error, " Array 2 is of wrong type." )
1214
+ end select
1215
+ call delete_file(output_file)
1216
+ end
1217
+
1093
1218
! > Makes sure that we find the file when running both `ctest` and `fpm test`.
1094
1219
function get_path (file ) result(path)
1095
1220
character (* ), intent (in ) :: file
0 commit comments