|
1 | | -import concurrent.futures |
| 1 | +import inspect |
2 | 2 | import pickle |
3 | | -import time |
4 | 3 |
|
| 4 | +from . import path as pathlib |
5 | 5 | from . import printing |
6 | | -from . import path |
7 | 6 | from . import timer |
| 7 | +from . import utils |
8 | 8 |
|
9 | 9 |
|
10 | 10 | class Saveable: |
@@ -35,97 +35,167 @@ def load(self, data): |
35 | 35 |
|
class Checkpoint:

  """
  Checkpoints are stored in this file structure:

  directory/
    latest  # Contains folder name of latest complete save.
    <timestamp>-<step>/
      foo.pkl
      bar.pkl
      baz-0.pkl
      baz-1.pkl
      baz-2.pkl
      done  # Empty file marking the save as complete.
    ...
  """

  def __init__(self, directory=None, keep=1, step=None, write=True):
    # keep=None disables cleanup; otherwise the newest `keep` checkpoint
    # folders are retained after each save.
    assert keep is None or keep >= 1
    self._directory = directory and pathlib.Path(directory)
    self._keep = keep
    self._step = step
    self._write = write
    self._saveables = {}

  def __setattr__(self, name, value):
    # Underscore-prefixed names are ordinary instance attributes. Public
    # names register a checkpointed object, which must expose callable
    # save() and load() methods.
    if name.startswith('_'):
      return super().__setattr__(name, value)
    has_load = hasattr(value, 'load') and callable(value.load)
    has_save = hasattr(value, 'save') and callable(value.save)
    if not (has_load and has_save):
      raise ValueError(
          f"Checkpointed object '{name}' must implement save() and load().")
    self._saveables[name] = value

  def __getattr__(self, name):
    # Only invoked for names not found via normal attribute lookup.
    if name.startswith('_'):
      raise AttributeError(name)
    try:
      return self._saveables[name]
    except (KeyError, AttributeError):
      # An unknown entry raises KeyError from the dict lookup; catching
      # only AttributeError (as before) let KeyError escape and the
      # intended ValueError was never raised. AttributeError is still
      # handled for the case where _saveables is not set up yet.
      raise ValueError(name)

  def exists(self, path=None):
    """Check for a complete checkpoint and print the outcome."""
    assert self._directory or path
    if path:
      result = exists(path)
    else:
      result = bool(self.latest())
    if result:
      print('Found existing checkpoint.')
    else:
      print('Did not find any checkpoint.')
    return result

  @timer.section('checkpoint_save')
  def save(self, path=None, keys=None):
    """Save the registered objects, or the subset named by `keys`.

    Without an explicit path, a new timestamped folder is created inside
    the checkpoint directory, the 'latest' pointer file is updated, and
    old folders are cleaned up.
    """
    assert self._directory or path
    if keys is None:
      savefns = {k: v.save for k, v in self._saveables.items()}
    else:
      assert all([not k.startswith('_') for k in keys]), keys
      savefns = {k: self._saveables[k].save for k in keys}
    if path:
      folder = None
    else:
      folder = utils.timestamp(millis=True)
      if self._step is not None:
        folder += f'-{int(self._step):012d}'
      path = self._directory / folder
    printing.print_(f'Saving checkpoint: {path}')
    save(path, savefns, self._write)
    if folder and self._write:
      # Only saves into our own directory update the 'latest' pointer and
      # trigger cleanup of old folders.
      (self._directory / 'latest').write_text(folder)
      self._cleanup()
    print('Saved checkpoint.')

  @timer.section('checkpoint_load')
  def load(self, path=None, keys=None):
    """Load the registered objects, or the subset named by `keys`.

    Without an explicit path, loads the checkpoint referenced by the
    'latest' pointer file in the checkpoint directory.
    """
    assert self._directory or path
    if keys is None:
      loadfns = {k: v.load for k, v in self._saveables.items()}
    else:
      assert all([not k.startswith('_') for k in keys]), keys
      loadfns = {k: self._saveables[k].load for k in keys}
    if not path:
      path = self.latest()
    assert path
    printing.print_(f'Loading checkpoint: {path}')
    load(path, loadfns)
    print('Loaded checkpoint.')

  def load_or_save(self):
    """Load the latest checkpoint if one exists, otherwise save a new one."""
    if self.exists():
      self.load()
    else:
      self.save()

  def latest(self):
    """Return the path of the latest complete save, or None."""
    filename = (self._directory / 'latest')
    if not filename.exists():
      return None
    return self._directory / filename.read_text()

  def _cleanup(self):
    # Delete all but the newest `keep` checkpoint folders. Relies on
    # folder names sorting chronologically via their timestamp prefix.
    if not self._keep:
      return
    folders = self._directory.glob('*')
    folders = [x for x in folders if x.name != 'latest']
    old = sorted(folders)[:-self._keep]
    for folder in old:
      folder.remove(recursive=True)
| 148 | + |
| 149 | + |
def exists(path):
  """Return True if the folder at `path` holds a completed checkpoint."""
  marker = pathlib.Path(path) / 'done'
  return marker.exists()
| 153 | + |
| 154 | + |
def save(path, savefns, write=True):
  """Write a checkpoint folder from the given save functions.

  Each entry in `savefns` maps a name to a zero-argument callable that
  returns either a picklable value (written as <name>.pkl) or a generator
  of picklable shards (written as <name>-0000.pkl, <name>-0001.pkl, ...).
  An empty 'done' marker file is written last so partial saves are never
  mistaken for complete ones. When `write` is False everything is still
  computed (generators are drained) but nothing touches disk.
  """
  path = pathlib.Path(path)
  assert not exists(path), path
  write and path.mkdir(parents=True)
  for name, savefn in savefns.items():
    try:
      data = savefn()
      if inspect.isgenerator(data):
        for i, shard in enumerate(data):
          # Four-digit shard suffixes only stay sortable below 1e5 shards.
          assert i < 1e5, i
          if write:  # Iterate even if we're not writing.
            buffer = pickle.dumps(shard)
            (path / f'{name}-{i:04d}.pkl').write_bytes(buffer)
      else:
        if write:
          buffer = pickle.dumps(data)
          (path / f'{name}.pkl').write_bytes(buffer)
    except Exception:
      # Fixed grammar in the error message ("Error save" -> "Error saving").
      print(f"Error saving '{name}' to checkpoint.")
      raise
  write and (path / 'done').write_bytes(b'')
| 176 | + |
| 177 | + |
def load(path, loadfns):
  """Read a completed checkpoint folder, dispatching to the given loaders.

  Each entry in `loadfns` maps a name to a callable that receives either
  the unpickled <name>.pkl value or, for sharded entries, a generator
  yielding the unpickled shards in order.
  """
  path = pathlib.Path(path)
  assert exists(path), path
  available = set(path.glob('*'))
  for name, loadfn in loadfns.items():
    try:
      single = path / f'{name}.pkl'
      first_shard = path / f'{name}-0000.pkl'
      if single in available:
        loadfn(pickle.loads(single.read_bytes()))
      elif first_shard in available:
        parts = sorted(x for x in available if x.name.startswith(f'{name}-'))
        def shards():
          for part in parts:
            yield pickle.loads(part.read_bytes())
        loadfn(shards())
      else:
        raise KeyError(name)
    except Exception:
      print(f"Error loading '{name}' from checkpoint.")
      raise
0 commit comments