5
5
#:set KINDS_TYPES = REAL_KINDS_TYPES + INT_KINDS_TYPES + CMPLX_KINDS_TYPES
6
6
7
7
!> Implementation of loading npy files into multidimensional arrays
8
- submodule (stdlib_io_np) stdlib_io_npy_load
9
- use stdlib_error, only : error_stop
10
- use stdlib_strings, only : to_string, starts_with
8
+ submodule(stdlib_io_np) stdlib_io_np_load
9
+ use stdlib_error, only: error_stop
10
+ use stdlib_strings, only: to_string, starts_with
11
+ use stdlib_string_type, only: string_type
12
+ use stdlib_io_zip, only: unzip, zip_prefix, zip_suffix, t_unzipped_bundle, t_unzipped_file
13
+ use stdlib_array
11
14
implicit none
12
15
13
16
contains
@@ -33,28 +36,12 @@ contains
33
36
34
37
open(newunit=io, file=filename, form="unformatted", access="stream", iostat=stat)
35
38
catch: block
36
- character(len=:), allocatable :: this_type
37
39
integer, allocatable :: vshape(:)
38
40
39
- call get_descriptor (io, filename, this_type , vshape, stat, msg)
41
+ call verify_npy_file (io, filename, vtype , vshape, rank , stat, msg)
40
42
if (stat /= 0) exit catch
41
43
42
- if (this_type /= vtype) then
43
- stat = 1
44
- msg = "File '"//filename//"' contains data of type '"//this_type//"', "//&
45
- & "but expected '"//vtype//"'"
46
- exit catch
47
- end if
48
-
49
- if (size(vshape) /= rank) then
50
- stat = 1
51
- msg = "File '"//filename//"' contains data of rank "//&
52
- & to_string(size(vshape))//", but expected "//&
53
- & to_string(rank)
54
- exit catch
55
- end if
56
-
57
- call allocator(array, vshape, stat)
44
+ call allocate_array(array, vshape, stat)
58
45
if (stat /= 0) then
59
46
msg = "Failed to allocate array of type '"//vtype//"' "//&
60
47
& "with total size of "//to_string(product(vshape))
@@ -76,30 +63,210 @@ contains
76
63
end if
77
64
78
65
if (present(iomsg).and.allocated(msg)) call move_alloc(msg, iomsg)
79
- contains
66
+ end
67
+ #:endfor
68
+ #:endfor
80
69
81
- !> Wrapped intrinsic allocate to create an allocation from a shape array
82
- subroutine allocator(array, vshape, stat)
83
- !> Instance of the array to be allocated
84
- ${t1}$, allocatable, intent(out) :: array${ranksuffix(rank)}$
85
- !> Dimensions to allocate for
86
- integer, intent(in) :: vshape(:)
87
- !> Status of allocate
70
!> Check that the npy data behind `io` has the expected element type and rank.
!>
!> The descriptor is obtained from `get_descriptor`; any failure reported there
!> is passed on unchanged. A type or rank mismatch yields `stat = 1` together
!> with a message naming the file and both the found and the expected value.
subroutine verify_npy_file(io, filename, vtype, vshape, rank, stat, msg)
    !> Access unit to the npy file.
    integer, intent(in) :: io
    !> Name of the npy file to load from.
    character(len=*), intent(in) :: filename
    !> Expected type of the stored data, compared against the file's descriptor.
    character(len=*), intent(in) :: vtype
    !> Shape of the stored data as returned by `get_descriptor`.
    integer, allocatable, intent(out) :: vshape(:)
    !> Expected rank of the data.
    integer, intent(in) :: rank
    !> Status of operation.
    integer, intent(out) :: stat
    !> Associated error message in case of non-zero status.
    character(len=:), allocatable, intent(out) :: msg

    character(len=:), allocatable :: found_type

    call get_descriptor(io, filename, found_type, vshape, stat, msg)

    ! Only inspect the descriptor when it could actually be retrieved.
    if (stat == 0) then
        if (found_type /= vtype) then
            stat = 1
            msg = "File '"//filename//"' contains data of type '"//found_type//"', "//&
                & "but expected '"//vtype//"'"
        else if (size(vshape) /= rank) then
            stat = 1
            msg = "File '"//filename//"' contains data of rank "//&
                & to_string(size(vshape))//", but expected "//&
                & to_string(rank)
        end if
    end if
end subroutine verify_npy_file
107
+
108
#:for k1, t1 in KINDS_TYPES
#:for rank in RANKS
!> Allocate `array` with the extents given by the first ${rank}$ entries of `vshape`.
!> The outcome of the `allocate` statement is reported through `stat`.
module subroutine allocate_array_${t1[0]}$${k1}$_${rank}$(array, vshape, stat)
    !> Instance of the array to be allocated.
    ${t1}$, allocatable, intent(out) :: array${ranksuffix(rank)}$
    !> Dimensions to allocate for.
    integer, intent(in) :: vshape(:)
    !> Status of allocate.
    integer, intent(out) :: stat

    ! One extent per dimension; every entry but the last is followed by a comma.
    allocate(array( &
    #:for dim in range(1, rank + 1)
        & vshape(${dim}$)${"," if dim < rank else ""}$ &
    #:endfor
        & ), stat=stat)
end subroutine allocate_array_${t1[0]}$${k1}$_${rank}$
#:endfor
#:endfor
124
+
125
!> Version: experimental
!>
!> Load multidimensional arrays from a compressed or uncompressed npz file.
!> ([Specification](../page/specs/stdlib_io.html#load_npz))
!>
!> The archive is unpacked with `unzip` and its entries are converted by
!> `load_raw_to_bundle`. If unpacking fails, `identify_problem` is asked for a
!> more specific diagnosis. Without `iostat`, any failure ends in `error_stop`.
module subroutine load_npz_to_bundle(filename, array_bundle, iostat, iomsg)
    !> Name of the npz file to load from.
    character(len=*), intent(in) :: filename
    !> Bundle receiving the arrays stored in the file.
    type(t_array_bundle), intent(out) :: array_bundle
    !> Status of the operation; when absent, a failure stops the program.
    integer, intent(out), optional :: iostat
    !> Error message; only set when one was produced.
    character(len=:), allocatable, intent(out), optional :: iomsg

    type(t_unzipped_bundle) :: raw_bundle
    integer :: stat
    character(len=:), allocatable :: msg, report

    call unzip(filename, raw_bundle, stat, msg)
    if (stat /= 0) then
        call identify_problem(filename, stat, msg)
    else
        call load_raw_to_bundle(raw_bundle, array_bundle, stat, msg)
    end if

    ! Nobody is listening for the status: report the failure and stop.
    if (.not. present(iostat) .and. stat /= 0) then
        report = "Failed to read arrays from file '"//filename//"'"
        if (allocated(msg)) report = report//nl//msg
        call error_stop(report)
    end if

    if (present(iostat)) iostat = stat

    if (present(iomsg)) then
        if (allocated(msg)) call move_alloc(msg, iomsg)
    end if
end subroutine load_npz_to_bundle
158
+
159
!> Convert every entry of an unzipped npz archive into an array of `array_bundle`.
!>
!> Each entry's raw bytes are staged in a scratch stream file, which is then
!> handed to `load_string_to_array` for parsing. Processing stops at the first
!> failure; `stat` and `msg` then describe that failure.
module subroutine load_raw_to_bundle(unzipped_bundle, array_bundle, stat, msg)
    !> Raw entries extracted from the npz archive.
    type(t_unzipped_bundle), intent(in) :: unzipped_bundle
    !> Bundle receiving one array per archive entry.
    type(t_array_bundle), intent(out) :: array_bundle
    !> Status of operation, zero on success.
    integer, intent(out) :: stat
    !> Associated error message in case of non-zero status.
    character(len=:), allocatable, intent(out) :: msg

    integer :: i, io, close_stat

    ! Define the status even when the bundle holds no entries at all.
    stat = 0

    allocate (array_bundle%files(size(unzipped_bundle%files)))
    do i = 1, size(unzipped_bundle%files)
        array_bundle%files(i)%name = unzipped_bundle%files(i)%name

        open (newunit=io, status='scratch', form='unformatted', access='stream', iostat=stat)
        if (stat /= 0) then
            msg = "Failed to open scratch file for '"//unzipped_bundle%files(i)%name//"'"
            return
        end if

        write (io, iostat=stat) unzipped_bundle%files(i)%data
        ! After the write the stream is positioned at its end; go back to the
        ! start so that the npy header can be read from the first byte.
        if (stat == 0) rewind (io, iostat=stat)

        if (stat == 0) then
            call load_string_to_array(io, unzipped_bundle%files(i), array_bundle%files(i), stat, msg)
        else
            msg = "Failed to stage data of '"//unzipped_bundle%files(i)%name//"' in scratch file"
        end if

        ! Always release the scratch file, but keep the status of the load
        ! separate so that closing cannot mask an earlier error.
        close (io, status='delete', iostat=close_stat)
        if (stat /= 0) return

        if (close_stat /= 0) then
            stat = close_stat
            msg = "Failed to close scratch file for '"//unzipped_bundle%files(i)%name//"'"
            return
        end if
    end do
end subroutine load_raw_to_bundle
178
+
179
!> Read one npy payload from unit `io` into the given array wrapper.
!>
!> The dynamic type of `array` determines which element type and rank are
!> expected. The npy header is checked with `verify_npy_file`, a temporary of
!> matching type and shape is allocated and read, and its contents are then
!> assigned to the `values` component of the wrapper.
module subroutine load_string_to_array(io, unzipped_file, array, stat, msg)
    !> Unit connected to the npy data to read.
    integer, intent(in) :: io
    !> Archive entry the data stems from; its `name` is used in diagnostics.
    type(t_unzipped_file), intent(in) :: unzipped_file
    !> Array wrapper to fill.
    class(t_array), intent(inout) :: array
    !> Status of operation, zero on success.
    integer, intent(out) :: stat
    !> Associated error message in case of non-zero status.
    character(len=:), allocatable, intent(out) :: msg

    ! One temporary per supported type/rank combination; only the one that
    ! matches the dynamic type of `array` is ever allocated.
    #:for k1, t1 in KINDS_TYPES
    #:for rank in RANKS
    ${t1}$, allocatable :: array_${t1[0]}$${k1}$_${rank}$${ranksuffix(rank)}$
    #:endfor
    #:endfor
    integer, allocatable :: vshape(:)

    select type (arr => array)
      #:for k1, t1 in KINDS_TYPES
      #:for rank in RANKS
      type is (t_array_${t1[0]}$${k1}$_${rank}$)
        call verify_npy_file(io, unzipped_file%name, type_${t1[0]}$${k1}$, vshape, ${rank}$, stat, msg)
        if (stat /= 0) return

        call allocate_array(array_${t1[0]}$${k1}$_${rank}$, vshape, stat)
        if (stat /= 0) then
            msg = "Failed to allocate array of type '"//type_${t1[0]}$${k1}$//"' "//&
                & "with total size of "//to_string(product(vshape))
            return
        end if

        read (io, iostat=stat) array_${t1[0]}$${k1}$_${rank}$${ranksuffix(rank)}$
        ! Do not hand partially read data to the caller: report the failure.
        if (stat /= 0) then
            msg = "Failed to read array of type '"//type_${t1[0]}$${k1}$//"' "//&
                & "with total size of "//to_string(product(vshape))
            return
        end if

        arr%values = array_${t1[0]}$${k1}$_${rank}$${ranksuffix(rank)}$
      #:endfor
      #:endfor
      class default
        stat = 1; msg = 'Unsupported array type.'; return
    end select
end subroutine load_string_to_array
214
+
215
!> Open file and try to identify the problem.
!>
!> Called after unzipping has failed. The routine checks whether the file
!> exists, whether it can be opened and whether `verify_header` accepts it. If
!> one of these checks fails, `stat` and `msg` describe that failure. Otherwise
!> the incoming status is restored and the incoming message, if any, is
!> appended to a generic "Failed to unzip file" text.
module subroutine identify_problem(filename, stat, msg)
    !> Name of the file that could not be unzipped.
    character(len=*), intent(in) :: filename
    !> On entry the status of the failed operation, on exit the refined status.
    integer, intent(inout) :: stat
    !> On entry the message of the failed operation (may be unallocated),
    !> on exit the refined message.
    character(len=:), allocatable, intent(inout) :: msg

    !> Length of the buffer receiving the message of a failed `open`.
    integer, parameter :: open_msg_len = 512

    logical :: exists
    integer :: io_unit, prev_stat, close_stat
    character(len=:), allocatable :: prev_msg
    character(len=open_msg_len) :: open_msg

    ! Keep track of the previous status and message in case no reason can be found.
    prev_stat = stat
    if (allocated(msg)) call move_alloc(msg, prev_msg)

    inquire (file=filename, exist=exists)
    if (.not. exists) then
        stat = 1; msg = 'File does not exist: '//filename//'.'; return
    end if

    ! `msg` is unallocated at this point, so collect the message of `open` in
    ! a fixed-length buffer instead of passing `msg` as `iomsg=` directly.
    open_msg = ''
    open (newunit=io_unit, file=filename, form='unformatted', access='stream', &
        & status='old', action='read', iostat=stat, iomsg=open_msg)
    if (stat /= 0) then
        msg = trim(open_msg)
        return
    end if

    call verify_header(io_unit, stat, msg)
    ! The unit is only needed for the header check; release it on every path.
    close (io_unit, iostat=close_stat)
    if (stat /= 0) return

    ! Restore previous status and message if no reason could be found.
    stat = prev_stat
    msg = 'Failed to unzip file: '//filename
    if (allocated(prev_msg)) msg = msg//nl//prev_msg
end subroutine identify_problem
243
+
244
!> Inspect the leading bytes of the file connected to `io_unit`.
!>
!> A non-zero `stat` with a message is returned when the file is smaller than
!> `zip_suffix`, when the leading bytes cannot be read, when they equal
!> `zip_suffix` (reported as an empty npz file) or when they differ from
!> `zip_prefix` (reported as not being an npz file).
module subroutine verify_header(io_unit, stat, msg)
    !> Unit connected to the file to inspect.
    integer, intent(in) :: io_unit
    !> Status of operation, zero if the leading bytes match `zip_prefix`.
    integer, intent(out) :: stat
    !> Associated error message in case of non-zero status.
    character(len=:), allocatable, intent(out) :: msg

    integer :: n_bytes
    character(len=len(zip_prefix)) :: magic

    inquire (io_unit, size=n_bytes)
    if (n_bytes < len(zip_suffix)) then
        stat = 1
        msg = 'File is too small to be an npz file.'
        return
    end if

    read (io_unit, iostat=stat) magic

    ! The first matching condition decides the outcome; if none matches, the
    ! zero status of the successful read is kept.
    if (stat /= 0) then
        msg = 'Failed to read header from file'
    else if (magic == zip_suffix) then
        stat = 1
        msg = 'Empty npz file.'
    else if (magic /= zip_prefix) then
        stat = 1
        msg = 'Not an npz file.'
    end if
end subroutine verify_header
103
270
104
271
!> Read the npy header from a binary file and retrieve the descriptor string.
105
272
subroutine get_descriptor(io, filename, vtype, vshape, stat, msg)
@@ -168,7 +335,7 @@ contains
168
335
if (.not.fortran_order) then
169
336
vshape = [(vshape(i), i = size(vshape), 1, -1)]
170
337
end if
171
- end subroutine get_descriptor
338
+ end
172
339
173
340
174
341
!> Parse the first eight bytes of the npy header to verify the data
@@ -214,7 +381,7 @@ contains
214
381
& "'"//to_string(major)//"."//to_string(minor)//"'"
215
382
return
216
383
end if
217
- end subroutine parse_header
384
+ end
218
385
219
386
!> Parse the descriptor in the npy header. This routine implements a minimal
220
387
!> non-recursive parser for serialized Python dictionaries.
@@ -367,7 +534,7 @@ contains
367
534
& "1 | " // input // nl // &
368
535
& " |" // repeat(" ", first) // repeat("^", last - first + 1) // nl // &
369
536
& " |"
370
- end function make_message
537
+ end
371
538
372
539
!> Parse a tuple of integers into an array of integers
373
540
subroutine parse_tuple(input, pos, tuple, stat, msg)
@@ -427,7 +594,7 @@ contains
427
594
return
428
595
end select
429
596
end do
430
- end subroutine parse_tuple
597
+ end
431
598
432
599
!> Get the next allowed token
433
600
subroutine next_token(input, pos, token, allowed_token, stat, msg)
@@ -459,7 +626,7 @@ contains
459
626
exit
460
627
end if
461
628
end do
462
- end subroutine next_token
629
+ end
463
630
464
631
!> Tokenize input string
465
632
subroutine get_token(input, pos, token)
@@ -531,8 +698,8 @@ contains
531
698
token = token_type(pos, pos, invalid)
532
699
end select
533
700
534
- end subroutine get_token
701
+ end
535
702
536
- end subroutine parse_descriptor
703
+ end
537
704
538
- end submodule stdlib_io_npy_load
705
+ end
0 commit comments