File tree Expand file tree Collapse file tree 3 files changed +11
-2
lines changed
denoising_diffusion_pytorch Expand file tree Collapse file tree 3 files changed +11
-2
lines changed Original file line number Diff line number Diff line change 2323
2424from accelerate import Accelerator
2525
26+ from denoising_diffusion_pytorch .version import __version__
27+
2628# constants
2729
2830ModelPrediction = namedtuple ('ModelPrediction' , ['pred_noise' , 'pred_x_start' ])
@@ -828,7 +830,8 @@ def save(self, milestone):
828830 'model' : self .accelerator .get_state_dict (self .model ),
829831 'opt' : self .opt .state_dict (),
830832 'ema' : self .ema .state_dict (),
831- 'scaler' : self .accelerator .scaler .state_dict () if exists (self .accelerator .scaler ) else None
833+ 'scaler' : self .accelerator .scaler .state_dict () if exists (self .accelerator .scaler ) else None ,
834+ 'version' : __version__
832835 }
833836
834837 torch .save (data , str (self .results_folder / f'model-{ milestone } .pt' ))
@@ -846,6 +849,9 @@ def load(self, milestone):
846849 self .opt .load_state_dict (data ['opt' ])
847850 self .ema .load_state_dict (data ['ema' ])
848851
852+ if 'version' in data :
853+ print (f"loading from version { data ['version' ]} " )
854+
849855 if exists (self .accelerator .scaler ) and exists (data ['scaler' ]):
850856 self .accelerator .scaler .load_state_dict (data ['scaler' ])
851857
Original file line number Diff line number Diff line change 1+ __version__ = '0.1.0'
Original file line number Diff line number Diff line change 11from setuptools import setup , find_packages
22
3+ exec (open ('denoising_diffusion_pytorch/version.py' ).read ())
4+
35setup (
46 name = 'denoising-diffusion-pytorch' ,
57 packages = find_packages (),
6- version = '0.32.0' ,
8+ version = __version__ ,
79 license = 'MIT' ,
810 description = 'Denoising Diffusion Probabilistic Models - Pytorch' ,
911 author = 'Phil Wang' ,
You can’t perform that action at this time.
0 commit comments