Skip to content

Commit 25fdd87

Browse files
committed
Fix npz loading
1 parent 8708ff9 commit 25fdd87

File tree

6 files changed

+191
-108
lines changed

6 files changed

+191
-108
lines changed

src/stdlib_array.fypp

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,18 @@ module stdlib_array
1414

1515
public :: trueloc, falseloc
1616

17-
type, public :: t_array_bundle
18-
class(t_array), allocatable :: files(:)
17+
!> Helper class to allocate t_array as an abstract type.
18+
type, public :: t_array_wrapper
19+
class(t_array), allocatable :: array
20+
21+
contains
22+
23+
#:for k1, t1 in KINDS_TYPES
24+
#:for rank in RANKS
25+
generic :: allocate_array => allocate_array_${t1[0]}$${k1}$_${rank}$
26+
procedure :: allocate_array_${t1[0]}$${k1}$_${rank}$
27+
#:endfor
28+
#:endfor
1929
end type
2030

2131
type, abstract, public :: t_array
@@ -32,6 +42,30 @@ module stdlib_array
3242

3343
contains
3444

45+
#:for k1, t1 in KINDS_TYPES
46+
#:for rank in RANKS
47+
!> Allocate an instance of the array within the wrapper.
48+
module subroutine allocate_array_${t1[0]}$${k1}$_${rank}$ (wrapper, array, stat, msg)
49+
class(t_array_wrapper), intent(out) :: wrapper
50+
${t1}$, intent(in) :: array${ranksuffix(rank)}$
51+
integer, intent(out) :: stat
52+
character(len=:), allocatable, intent(out) :: msg
53+
54+
allocate (t_array_${t1[0]}$${k1}$_${rank}$ :: wrapper%array, stat=stat)
55+
if (stat /= 0) then
56+
msg = 'Failed to allocate array.'; return
57+
end if
58+
59+
select type (typed_array => wrapper%array)
60+
class is (t_array_${t1[0]}$${k1}$_${rank}$)
61+
typed_array%values = array
62+
class default
63+
msg = 'Failed to allocate values.'; stat = 1; return
64+
end select
65+
end
66+
#:endfor
67+
#:endfor
68+
3569
!> Version: experimental
3670
!>
3771
!> Return the positions of the true elements in array.
@@ -45,7 +79,7 @@ contains
4579
integer :: loc(count(array))
4680

4781
call logicalloc(loc, array, .true., lbound)
48-
end function trueloc
82+
end
4983

5084
!> Version: experimental
5185
!>

src/stdlib_io_np.fypp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@
7070
!> utf8-encoded string, so supports structured types with any unicode field names.
7171
module stdlib_io_np
7272
use stdlib_kinds, only: int8, int16, int32, int64, sp, dp, xdp, qp
73-
use stdlib_array, only: t_array_bundle
73+
use stdlib_array, only: t_array_wrapper
7474
implicit none
7575
private
7676

@@ -121,9 +121,9 @@ module stdlib_io_np
121121
!> Load multiple multidimensional arrays from a (compressed) npz file.
122122
!> ([Specification](../page/specs/stdlib_io.html#load_npz))
123123
interface load_npz
124-
module subroutine load_npz_to_bundle(filename, array_bundle, iostat, iomsg)
124+
module subroutine load_npz_to_arrays(filename, arrays, iostat, iomsg)
125125
character(len=*), intent(in) :: filename
126-
type(t_array_bundle), intent(out) :: array_bundle
126+
type(t_array_wrapper), allocatable, intent(out) :: arrays(:)
127127
integer, intent(out), optional :: iostat
128128
character(len=:), allocatable, intent(out), optional :: iomsg
129129
end
@@ -134,9 +134,9 @@ module stdlib_io_np
134134
!> Save multidimensional arrays to a compressed or an uncompressed npz file.
135135
!> ([Specification](../page/specs/stdlib_io.html#save_npz))
136136
interface save_npz
137-
module subroutine save_npz_from_bundle(filename, array_bundle, compressed, iostat, iomsg)
137+
module subroutine save_npz_from_arrays(filename, arrays, compressed, iostat, iomsg)
138138
character(len=*), intent(in) :: filename
139-
type(t_array_bundle), intent(in) :: array_bundle
139+
type(t_array_wrapper), intent(in) :: arrays(*)
140140
!> If true, the file is saved in compressed format. The default is false.
141141
logical, intent(in), optional :: compressed
142142
integer, intent(out), optional :: iostat

src/stdlib_io_np_load.fypp

Lines changed: 76 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,27 @@ contains
3636

3737
open(newunit=io, file=filename, form="unformatted", access="stream", iostat=stat)
3838
catch: block
39+
character(len=:), allocatable :: this_type
3940
integer, allocatable :: vshape(:)
4041

41-
call verify_npy_file(io, filename, vtype, vshape, rank, stat, msg)
42+
call get_descriptor(io, filename, this_type, vshape, stat, msg)
4243
if (stat /= 0) exit catch
4344

45+
if (this_type /= vtype) then
46+
stat = 1
47+
msg = "File '"//filename//"' contains data of type '"//this_type//"', "//&
48+
& "but expected '"//vtype//"'"
49+
exit catch
50+
end if
51+
52+
if (size(vshape) /= rank) then
53+
stat = 1
54+
msg = "File '"//filename//"' contains data of rank "//&
55+
& to_string(size(vshape))//", but expected "//&
56+
& to_string(rank)
57+
exit catch
58+
end if
59+
4460
call allocate_array(array, vshape, stat)
4561
if (stat /= 0) then
4662
msg = "Failed to allocate array of type '"//vtype//"' "//&
@@ -67,44 +83,6 @@ contains
6783
#:endfor
6884
#:endfor
6985

70-
!> Verify header, type and rank of the npy file.
71-
subroutine verify_npy_file(io, filename, vtype, vshape, rank, stat, msg)
72-
!> Access unit to the npy file.
73-
integer, intent(in) :: io
74-
!> Name of the npy file to load from.
75-
character(len=*), intent(in) :: filename
76-
!> Type of the data stored, retrieved from field `descr`.
77-
character(len=*), intent(in) :: vtype
78-
!> Shape of the stored data, retrieved from field `shape`.
79-
integer, allocatable, intent(out) :: vshape(:)
80-
!> Expected rank of the data.
81-
integer, intent(in) :: rank
82-
!> Status of operation.
83-
integer, intent(out) :: stat
84-
!> Associated error message in case of non-zero status.
85-
character(len=:), allocatable, intent(out) :: msg
86-
87-
character(len=:), allocatable :: this_type
88-
89-
call get_descriptor(io, filename, this_type, vshape, stat, msg)
90-
if (stat /= 0) return
91-
92-
if (this_type /= vtype) then
93-
stat = 1
94-
msg = "File '"//filename//"' contains data of type '"//this_type//"', "//&
95-
& "but expected '"//vtype//"'"
96-
return
97-
end if
98-
99-
if (size(vshape) /= rank) then
100-
stat = 1
101-
msg = "File '"//filename//"' contains data of rank "//&
102-
& to_string(size(vshape))//", but expected "//&
103-
& to_string(rank)
104-
return
105-
end if
106-
end
107-
10886
#:for k1, t1 in KINDS_TYPES
10987
#:for rank in RANKS
11088
module subroutine allocate_array_${t1[0]}$${k1}$_${rank}$(array, vshape, stat)
@@ -126,9 +104,9 @@ contains
126104
!>
127105
!> Load multidimensional arrays from a compressed or uncompressed npz file.
128106
!> ([Specification](../page/specs/stdlib_io.html#load_npz))
129-
module subroutine load_npz_to_bundle(filename, array_bundle, iostat, iomsg)
107+
module subroutine load_npz_to_arrays(filename, arrays, iostat, iomsg)
130108
character(len=*), intent(in) :: filename
131-
type(t_array_bundle), intent(out) :: array_bundle
109+
type(t_array_wrapper), allocatable, intent(out) :: arrays(:)
132110
integer, intent(out), optional :: iostat
133111
character(len=:), allocatable, intent(out), optional :: iomsg
134112

@@ -138,9 +116,9 @@ contains
138116

139117
call unzip(filename, unzipped_bundle, stat, msg)
140118
if (stat == 0) then
141-
call load_raw_to_bundle(unzipped_bundle, array_bundle, stat, msg)
119+
call load_unzipped_bundle_to_arrays(unzipped_bundle, arrays, stat, msg)
142120
else
143-
call identify_problem(filename, stat, msg)
121+
call identify_unzip_problem(filename, stat, msg)
144122
end if
145123

146124
if (present(iostat)) then
@@ -156,64 +134,77 @@ contains
156134
if (present(iomsg) .and. allocated(msg)) call move_alloc(msg, iomsg)
157135
end
158136

159-
module subroutine load_raw_to_bundle(unzipped_bundle, array_bundle, stat, msg)
137+
module subroutine load_unzipped_bundle_to_arrays(unzipped_bundle, arrays, stat, msg)
160138
type(t_unzipped_bundle), intent(in) :: unzipped_bundle
161-
type(t_array_bundle), intent(out) :: array_bundle
139+
type(t_array_wrapper), allocatable, intent(out) :: arrays(:)
162140
integer, intent(out) :: stat
163141
character(len=:), allocatable, intent(out) :: msg
164142

165143
integer :: i, io
144+
integer, allocatable :: vshape(:)
145+
character(len=:), allocatable :: this_type
146+
147+
allocate (arrays(size(unzipped_bundle%files)))
166148

167-
allocate (array_bundle%files(size(unzipped_bundle%files)))
168149
do i = 1, size(unzipped_bundle%files)
169-
array_bundle%files(i)%name = unzipped_bundle%files(i)%name
170-
open (newunit=io, status='scratch', form='unformatted', access='stream', iostat=stat)
171-
if (stat /= 0) return
172-
write (io) unzipped_bundle%files(i)%data
173-
call load_string_to_array(io, unzipped_bundle%files(i), array_bundle%files(i), stat, msg)
174-
close (io, status='delete', iostat=stat)
150+
open (newunit=io, status='scratch', form='unformatted', access='stream', iostat=stat, iomsg=msg)
175151
if (stat /= 0) return
176-
end do
177-
end
178-
179-
module subroutine load_string_to_array(io, unzipped_file, array, stat, msg)
180-
integer, intent(in) :: io
181-
type(t_unzipped_file), intent(in) :: unzipped_file
182-
class(t_array), intent(inout) :: array
183-
integer, intent(out) :: stat
184-
character(len=:), allocatable, intent(out) :: msg
185152

186-
#:for k1, t1 in KINDS_TYPES
187-
#:for rank in RANKS
188-
${t1}$, allocatable :: array_${t1[0]}$${k1}$_${rank}$${ranksuffix(rank)}$
189-
#:endfor
190-
#:endfor
153+
write (io, iostat=stat) unzipped_bundle%files(i)%data
154+
if (stat /= 0) then
155+
msg = 'Failed to write unzipped data to scratch file.'
156+
close (io, status='delete'); return
157+
end if
191158

192-
integer, allocatable :: vshape(:)
159+
rewind (io)
160+
call get_descriptor(io, unzipped_bundle%files(i)%name, this_type, vshape, stat, msg)
161+
if (stat /= 0) return
193162

194-
select type (arr => array)
163+
select case (this_type)
195164
#:for k1, t1 in KINDS_TYPES
165+
case (type_${t1[0]}$${k1}$)
166+
select case (size(vshape))
196167
#:for rank in RANKS
197-
type is (t_array_${t1[0]}$${k1}$_${rank}$)
198-
call verify_npy_file(io, unzipped_file%name, type_${t1[0]}$${k1}$, vshape, ${rank}$, stat, msg)
199-
if (stat /= 0) return
200-
call allocate_array(array_${t1[0]}$${k1}$_${rank}$, vshape, stat)
201-
if (stat /= 0) then
202-
msg = "Failed to allocate array of type '"//type_${t1[0]}$${k1}$//"' "//&
203-
& "with total size of "//to_string(product(vshape))
204-
return
205-
end if
206-
read (io, iostat=stat) array_${t1[0]}$${k1}$_${rank}$${ranksuffix(rank)}$
207-
arr%values = array_${t1[0]}$${k1}$_${rank}$${ranksuffix(rank)}$
168+
case (${rank}$)
169+
block
170+
${t1}$, allocatable :: array${ranksuffix(rank)}$
171+
172+
call allocate_array(array, vshape, stat)
173+
if (stat /= 0) then
174+
msg = "Failed to allocate array of type '"//this_type//"'."; return
175+
end if
176+
177+
read (io, iostat=stat) array
178+
if (stat /= 0) then
179+
msg = "Failed to read array of type '"//this_type//"' "//&
180+
& 'with total size of '//to_string(product(vshape)); return
181+
end if
182+
183+
call arrays(i)%allocate_array(array, stat, msg)
184+
if (stat /= 0) then
185+
msg = "Failed to allocate array of type '"//this_type//"' "//&
186+
& 'with total size of '//to_string(product(vshape)); return
187+
end if
188+
189+
arrays(i)%array%name = unzipped_bundle%files(i)%name
190+
end block
208191
#:endfor
192+
case default
193+
stat = 1; msg = 'Unsupported rank for array of type '//this_type//': '// &
194+
& to_string(size(vshape))//'.'; return
195+
end select
209196
#:endfor
210-
class default
211-
stat = 1; msg = 'Unsupported array type.'; return
212-
end select
197+
case default
198+
stat = 1; msg = 'Unsupported array type: '//this_type//'.'; return
199+
end select
200+
201+
close (io, status='delete')
202+
if (stat /= 0) return
203+
end do
213204
end
214205

215-
!> Open file and try to identify the problem.
216-
module subroutine identify_problem(filename, stat, msg)
206+
!> Open file and try to identify the cause of the error that occurred during unzip.
207+
module subroutine identify_unzip_problem(filename, stat, msg)
217208
character(len=*), intent(in) :: filename
218209
integer, intent(inout) :: stat
219210
character(len=:), allocatable, intent(inout) :: msg
@@ -291,7 +282,7 @@ contains
291282

292283
! stat should be zero if no error occurred
293284
stat = 0
294-
285+
295286
read(io, iostat=stat) header
296287
if (stat /= 0) return
297288

src/stdlib_io_np_save.fypp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,9 +135,9 @@ contains
135135
!>
136136
!> Save multidimensional arrays to a compressed or an uncompressed npz file.
137137
!> ([Specification](../page/specs/stdlib_io.html#save_npz))
138-
module subroutine save_npz_from_bundle(filename, array_bundle, compressed, iostat, iomsg)
138+
module subroutine save_npz_from_arrays(filename, arrays, compressed, iostat, iomsg)
139139
character(len=*), intent(in) :: filename
140-
type(t_array_bundle), intent(in) :: array_bundle
140+
type(t_array_wrapper), intent(in) :: arrays(*)
141141
!> If true, the file is saved in compressed format. The default is false.
142142
logical, intent(in), optional :: compressed
143143
integer, intent(out), optional :: iostat

src/stdlib_io_zip.f90

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
module stdlib_io_zip
2-
use stdlib_array, only: t_array_bundle
32
use stdlib_io_minizip
43
use iso_c_binding, only: c_ptr, c_associated, c_int, c_long, c_char
54
implicit none
@@ -13,7 +12,7 @@ module stdlib_io_zip
1312
integer(kind=c_long), parameter :: buffer_size = 1024
1413

1514
interface unzip
16-
module procedure unzip_to_raw
15+
module procedure unzip_to_bundle
1716
end interface
1817

1918
!> Contains extracted raw data from a zip file.
@@ -32,7 +31,7 @@ module stdlib_io_zip
3231

3332
contains
3433

35-
module subroutine unzip_to_raw(filename, bundle, iostat, iomsg)
34+
module subroutine unzip_to_bundle(filename, bundle, iostat, iomsg)
3635
character(len=*), intent(in) :: filename
3736
type(t_unzipped_bundle), intent(out) :: bundle
3837
integer, intent(out), optional :: iostat

0 commit comments

Comments
 (0)