Skip to content

Commit 61d4334

Browse files
DWeslMaanasArora
authored andcommitted
PERF: Use dict instead of list to make NpzFile member existence checks constant time (numpy#29098)
Use dict instead of list to convert the passed key to the name used in the archive.
1 parent c95e2ae commit 61d4334

File tree

1 file changed

+29
-35
lines changed

1 file changed

+29
-35
lines changed

numpy/lib/_npyio_impl.py

Lines changed: 29 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -195,16 +195,13 @@ def __init__(self, fid, own_fid=False, allow_pickle=False,
195195
# Import is postponed to here since zipfile depends on gzip, an
196196
# optional component of the so-called standard library.
197197
_zip = zipfile_factory(fid)
198-
self._files = _zip.namelist()
199-
self.files = []
198+
_files = _zip.namelist()
199+
self.files = [name.removesuffix(".npy") for name in _files]
200+
self._files = dict(zip(self.files, _files))
201+
self._files.update(zip(_files, _files))
200202
self.allow_pickle = allow_pickle
201203
self.max_header_size = max_header_size
202204
self.pickle_kwargs = pickle_kwargs
203-
for x in self._files:
204-
if x.endswith('.npy'):
205-
self.files.append(x[:-4])
206-
else:
207-
self.files.append(x)
208205
self.zip = _zip
209206
self.f = BagObj(self)
210207
if own_fid:
@@ -240,37 +237,34 @@ def __len__(self):
240237
return len(self.files)
241238

242239
def __getitem__(self, key):
243-
# FIXME: This seems like it will copy strings around
244-
# more than is strictly necessary. The zipfile
245-
# will read the string and then
246-
# the format.read_array will copy the string
247-
# to another place in memory.
248-
# It would be better if the zipfile could read
249-
# (or at least uncompress) the data
250-
# directly into the array memory.
251-
member = False
252-
if key in self._files:
253-
member = True
254-
elif key in self.files:
255-
member = True
256-
key += '.npy'
257-
if member:
258-
bytes = self.zip.open(key)
259-
magic = bytes.read(len(format.MAGIC_PREFIX))
260-
bytes.close()
261-
if magic == format.MAGIC_PREFIX:
262-
bytes = self.zip.open(key)
263-
return format.read_array(bytes,
264-
allow_pickle=self.allow_pickle,
265-
pickle_kwargs=self.pickle_kwargs,
266-
max_header_size=self.max_header_size)
267-
else:
268-
return self.zip.read(key)
240+
try:
241+
key = self._files[key]
242+
except KeyError:
243+
raise KeyError(f"{key} is not a file in the archive") from None
269244
else:
270-
raise KeyError(f"{key} is not a file in the archive")
245+
with self.zip.open(key) as bytes:
246+
magic = bytes.read(len(format.MAGIC_PREFIX))
247+
bytes.seek(0)
248+
if magic == format.MAGIC_PREFIX:
249+
# FIXME: This seems like it will copy strings around
250+
# more than is strictly necessary. The zipfile
251+
# will read the string and then
252+
# the format.read_array will copy the string
253+
# to another place in memory.
254+
# It would be better if the zipfile could read
255+
# (or at least uncompress) the data
256+
# directly into the array memory.
257+
return format.read_array(
258+
bytes,
259+
allow_pickle=self.allow_pickle,
260+
pickle_kwargs=self.pickle_kwargs,
261+
max_header_size=self.max_header_size
262+
)
263+
else:
264+
return bytes.read(key)
271265

272266
def __contains__(self, key):
273-
return (key in self._files or key in self.files)
267+
return (key in self._files)
274268

275269
def __repr__(self):
276270
# Get filename or default to `object`

0 commit comments

Comments
 (0)