Skip to content

Commit 6ab29d5

Browse files
committed
start saving version of library in the checkpoints, so in case of breaking changes, researchers can return to the right working version
1 parent ddc31bc commit 6ab29d5

File tree

3 files changed

+11
-2
lines changed

3 files changed

+11
-2
lines changed

denoising_diffusion_pytorch/denoising_diffusion_pytorch.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323

2424
from accelerate import Accelerator
2525

26+
from denoising_diffusion_pytorch.version import __version__
27+
2628
# constants
2729

2830
ModelPrediction = namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start'])
@@ -828,7 +830,8 @@ def save(self, milestone):
828830
'model': self.accelerator.get_state_dict(self.model),
829831
'opt': self.opt.state_dict(),
830832
'ema': self.ema.state_dict(),
831-
'scaler': self.accelerator.scaler.state_dict() if exists(self.accelerator.scaler) else None
833+
'scaler': self.accelerator.scaler.state_dict() if exists(self.accelerator.scaler) else None,
834+
'version': __version__
832835
}
833836

834837
torch.save(data, str(self.results_folder / f'model-{milestone}.pt'))
@@ -846,6 +849,9 @@ def load(self, milestone):
846849
self.opt.load_state_dict(data['opt'])
847850
self.ema.load_state_dict(data['ema'])
848851

852+
if 'version' in data:
853+
print(f"loading from version {data['version']}")
854+
849855
if exists(self.accelerator.scaler) and exists(data['scaler']):
850856
self.accelerator.scaler.load_state_dict(data['scaler'])
851857

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
__version__ = '0.1.0'

setup.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
from setuptools import setup, find_packages
22

3+
exec(open('denoising_diffusion_pytorch/version.py').read())
4+
35
setup(
46
name = 'denoising-diffusion-pytorch',
57
packages = find_packages(),
6-
version = '0.32.0',
8+
version = __version__,
79
license='MIT',
810
description = 'Denoising Diffusion Probabilistic Models - Pytorch',
911
author = 'Phil Wang',

0 commit comments

Comments
 (0)