Skip to content

Commit e8d140a

Browse files
committed
Multi-file checkpoint format
1 parent f2e0e1e commit e8d140a

File tree

5 files changed

+261
-101
lines changed

5 files changed

+261
-101
lines changed

elements/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = '3.17.0'
1+
__version__ = '3.18.1'
22

33
from .agg import Agg
44
from .checkpoint import Checkpoint, Saveable
@@ -17,6 +17,7 @@
1717
from .utils import timestamp
1818
from .uuid import UUID
1919

20+
from . import checkpoint
2021
from . import logger
2122
from . import plotting
2223
from . import timer

elements/checkpoint.py

Lines changed: 137 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
import concurrent.futures
1+
import inspect
22
import pickle
3-
import time
43

4+
from . import path as pathlib
55
from . import printing
6-
from . import path
76
from . import timer
7+
from . import utils
88

99

1010
class Saveable:
@@ -35,97 +35,167 @@ def load(self, data):
3535

3636
class Checkpoint:
3737

38-
def __init__(self, filename=None, parallel=True, write=True):
39-
self._filename = filename and path.Path(filename)
40-
self._values = {}
41-
self._parallel = parallel
38+
"""
39+
Checkpoints are stored in this file structure:
40+
41+
directory/
42+
latest # Contains folder name of latest complete save.
43+
<timestamp>-<step>/
44+
foo.pkl
45+
bar.pkl
46+
baz-0.pkl
47+
baz-1.pkl
48+
baz-2.pkl
49+
done # Empty file marking the save as complete.
50+
...
51+
"""
52+
53+
def __init__(self, directory=None, keep=1, step=None, write=True):
54+
assert keep is None or keep >= 1
55+
self._directory = directory and pathlib.Path(directory)
56+
self._keep = keep
57+
self._step = step
4258
self._write = write
43-
self._promise = None
44-
if self._parallel:
45-
self._worker = concurrent.futures.ThreadPoolExecutor(1, 'checkpoint')
59+
self._saveables = {}
4660

4761
def __setattr__(self, name, value):
48-
if name in ('exists', 'save', 'load'):
49-
return super().__setattr__(name, value)
5062
if name.startswith('_'):
5163
return super().__setattr__(name, value)
5264
has_load = hasattr(value, 'load') and callable(value.load)
5365
has_save = hasattr(value, 'save') and callable(value.save)
5466
if not (has_load and has_save):
55-
message = f"Checkpoint entry '{name}' must implement save() and load()."
56-
raise ValueError(message)
57-
self._values[name] = value
67+
raise ValueError(
68+
f"Checkpointed object '{name}' must implement save() and load().")
69+
self._saveables[name] = value
5870

5971
def __getattr__(self, name):
6072
if name.startswith('_'):
6173
raise AttributeError(name)
6274
try:
63-
return self._values[name]
75+
return self._saveables[name]
6476
except AttributeError:
6577
raise ValueError(name)
6678

67-
def exists(self, filename=None):
68-
assert self._filename or filename
69-
filename = path.Path(filename or self._filename)
70-
exists = self._filename.exists()
71-
if exists:
79+
def exists(self, path=None):
80+
assert self._directory or path
81+
if path:
82+
result = exists(path)
83+
else:
84+
result = bool(self.latest())
85+
if result:
7286
print('Found existing checkpoint.')
7387
else:
7488
print('Did not find any checkpoint.')
75-
return exists
76-
77-
def save(self, filename=None, keys=None):
78-
assert self._filename or filename
79-
filename = path.Path(filename or self._filename)
80-
printing.print_(f'Saving checkpoint: {filename}')
81-
keys = tuple(self._values.keys() if keys is None else keys)
82-
assert all([not k.startswith('_') for k in keys]), keys
83-
data = {k: self._values[k].save() for k in keys}
84-
if not self._write:
85-
return
86-
if self._parallel:
87-
self._promise and self._promise.result()
88-
self._promise = self._worker.submit(self._save, filename, data)
89-
else:
90-
self._save(filename, data)
89+
return result
9190

9291
@timer.section('checkpoint_save')
93-
def _save(self, filename, data):
94-
data['_timestamp'] = time.time()
95-
filename.parent.mkdir()
96-
content = pickle.dumps(data)
97-
if str(filename).startswith('gs://'):
98-
filename.write(content, mode='wb')
92+
def save(self, path=None, keys=None):
93+
assert self._directory or path
94+
if keys is None:
95+
savefns = {k: v.save for k, v in self._saveables.items()}
9996
else:
100-
# Write to a temporary file and then atomically rename, so that the file
101-
# either contains a complete write or not update at all if writing was
102-
# interrupted.
103-
tmp = filename.parent / (filename.name + '.tmp')
104-
tmp.write(content, mode='wb')
105-
tmp.move(filename)
106-
print('Wrote checkpoint.')
97+
assert all([not k.startswith('_') for k in keys]), keys
98+
savefns = {k: self._saveables[k].save for k in keys}
99+
if path:
100+
folder = None
101+
else:
102+
folder = utils.timestamp(millis=True)
103+
if self._step is not None:
104+
folder += f'-{int(self._step):012d}'
105+
path = self._directory / folder
106+
printing.print_(f'Saving checkpoint: {path}')
107+
save(path, savefns, self._write)
108+
if folder and self._write:
109+
(self._directory / 'latest').write_text(folder)
110+
self._cleanup()
111+
print('Saved checkpoint.')
107112

108113
@timer.section('checkpoint_load')
109-
def load(self, filename=None, keys=None):
110-
assert self._filename or filename
111-
self._promise and self._promise.result() # Wait for last save.
112-
filename = path.Path(filename or self._filename)
113-
printing.print_(f'Loading checkpoint: {filename}')
114-
data = pickle.loads(filename.read('rb'))
115-
keys = tuple(data.keys() if keys is None else keys)
116-
for key in keys:
117-
if key.startswith('_'):
118-
continue
119-
try:
120-
self._values[key].load(data[key])
121-
except Exception:
122-
print(f"Error loading '{key}' from checkpoint.")
123-
raise
124-
age = time.time() - data['_timestamp']
125-
printing.print_(f'Loaded checkpoint from {age:.0f} seconds ago.')
114+
def load(self, path=None, keys=None):
115+
assert self._directory or path
116+
if keys is None:
117+
loadfns = {k: v.load for k, v in self._saveables.items()}
118+
else:
119+
assert all([not k.startswith('_') for k in keys]), keys
120+
loadfns = {k: self._saveables[k].load for k in keys}
121+
if not path:
122+
path = self.latest()
123+
assert path
124+
printing.print_(f'Loading checkpoint: {path}')
125+
load(path, loadfns)
126+
print('Loaded checkpoint.')
126127

127128
def load_or_save(self):
128129
if self.exists():
129130
self.load()
130131
else:
131132
self.save()
133+
134+
def latest(self):
135+
filename = (self._directory / 'latest')
136+
if not filename.exists():
137+
return None
138+
return self._directory / filename.read_text()
139+
140+
def _cleanup(self):
141+
if not self._keep:
142+
return
143+
folders = self._directory.glob('*')
144+
folders = [x for x in folders if x.name != 'latest']
145+
old = sorted(folders)[:-self._keep]
146+
for folder in old:
147+
folder.remove(recursive=True)
148+
149+
150+
def exists(path):
151+
path = pathlib.Path(path)
152+
return (path / 'done').exists()
153+
154+
155+
def save(path, savefns, write=True):
156+
path = pathlib.Path(path)
157+
assert not exists(path), path
158+
write and path.mkdir(parents=True)
159+
for name, savefn in savefns.items():
160+
try:
161+
data = savefn()
162+
if inspect.isgenerator(data):
163+
for i, shard in enumerate(data):
164+
assert i < 1e5, i
165+
if write: # Iterate even if we're not writing.
166+
buffer = pickle.dumps(shard)
167+
(path / f'{name}-{i:04d}.pkl').write_bytes(buffer)
168+
else:
169+
if write:
170+
buffer = pickle.dumps(data)
171+
(path / f'{name}.pkl').write_bytes(buffer)
172+
except Exception:
173+
print(f"Error save '{name}' to checkpoint.")
174+
raise
175+
write and (path / 'done').write_bytes(b'')
176+
177+
178+
def load(path, loadfns):
179+
path = pathlib.Path(path)
180+
assert exists(path), path
181+
filenames = set(path.glob('*'))
182+
for name, loadfn in loadfns.items():
183+
try:
184+
if (path / f'{name}.pkl') in filenames:
185+
buffer = (path / f'{name}.pkl').read_bytes()
186+
data = pickle.loads(buffer)
187+
loadfn(data)
188+
elif (path / f'{name}-0000.pkl') in filenames:
189+
shards = [x for x in filenames if x.name.startswith(f'{name}-')]
190+
shards = sorted(shards)
191+
def generator():
192+
for filename in shards:
193+
buffer = filename.read_bytes()
194+
data = pickle.loads(buffer)
195+
yield data
196+
loadfn(generator())
197+
else:
198+
raise KeyError(name)
199+
except Exception:
200+
print(f"Error loading '{name}' from checkpoint.")
201+
raise

elements/path.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@ def __lt__(self, other):
5555
def __str__(self):
5656
return self._path
5757

58+
def __hash__(self):
59+
return hash(str(self))
60+
5861
@property
5962
def parent(self):
6063
if '/' not in self._path:

elements/utils.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
from datetime import datetime
22

33

4-
def timestamp(now=None, millis=False):
5-
now = datetime.now() if now is None else now
6-
string = now.strftime("%Y%m%dT%H%M%S")
4+
def timestamp(time=None, millis=False):
5+
if time is None:
6+
time = datetime.now()
7+
string = time.strftime("%Y%m%dT%H%M%S")
78
if millis:
8-
string += f'F{now.microsecond:06d}'
9+
string += f'F{time.microsecond:06d}'
910
return string

0 commit comments

Comments
 (0)