@@ -153,15 +153,24 @@ def dump(self, **kwargs):
153153
154154 def dump_each (self , ** kwargs ):
155155 """Write data to NPY files without considering `self.stride`."""
156- for key , value in kwargs .items ():
157- # Check array properties
158- shape , dtype = self .fields .get (key , (None , None ))
159- arvalue = np .asarray (value )
160- if shape is None :
161- shape = arvalue .shape
162- dtype = arvalue .dtype
163- self .fields [key ] = (shape , dtype )
164- else :
156+ converted = {}
157+ if len (self .fields ) == 0 :
158+ # No checking, just record the given shapes and types
159+ for key , value in kwargs .items ():
160+ arvalue = np .asarray (value )
161+ converted [key ] = arvalue
162+ self .fields [key ] = (arvalue .shape , arvalue .dtype )
163+ else :
164+ # Check kwargs
165+ if set (self .fields ) != set (kwargs ):
166+ raise TypeError (
167+ f"Received keys: { list (kwargs .keys ())} . "
168+ f"Expected: { list (self .fields .keys ())} "
169+ )
170+ for key , value in kwargs .items ():
171+ arvalue = np .asarray (value )
172+ converted [key ] = arvalue
173+ shape , dtype = self .fields [key ]
165174 if shape != arvalue .shape :
166175 raise TypeError (
167176 f"The shape of { key } , { arvalue .shape } , differs from the first one, { shape } "
@@ -170,7 +179,9 @@ def dump_each(self, **kwargs):
170179 raise TypeError (
171180 f"The dtype of { key } , { arvalue .dtype } , differs from the first one, { dtype } "
172181 )
173- # Append to NPY file
182+
183+ # Write only once all checks have passed
184+ for key , value in converted .items ():
174185 path = os .path .join (self .dir_out , f"{ key } .npy" )
175186 with NpyAppendArray (path , delete_if_exists = False ) as npaa :
176- npaa .append (arvalue .reshape (1 , * arvalue .shape ))
187+ npaa .append (value .reshape (1 , * value .shape ))
0 commit comments