|
1 | 1 | import math |
| 2 | +from multiprocessing import cpu_count |
| 3 | +from pathlib import Path |
2 | 4 | from random import random |
3 | 5 | from functools import partial |
4 | 6 | from collections import namedtuple |
5 | 7 |
|
6 | 8 | import torch |
| 9 | +from accelerate import Accelerator |
| 10 | +from ema_pytorch import EMA |
7 | 11 | from torch import nn, einsum |
8 | 12 | import torch.nn.functional as F |
9 | 13 |
|
10 | 14 | from einops import rearrange, reduce |
11 | 15 | from einops.layers.torch import Rearrange |
| 16 | +from torch.optim import Adam |
| 17 | +from torch.utils.data import Dataset, DataLoader |
12 | 18 |
|
13 | 19 | from tqdm.auto import tqdm |
14 | 20 |
|
| 21 | +from denoising_diffusion_pytorch.version import __version__ |
| 22 | + |
15 | 23 | # constants |
16 | 24 |
|
17 | 25 | ModelPrediction = namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start']) |
@@ -713,3 +721,170 @@ def forward(self, img, *args, **kwargs): |
713 | 721 |
|
714 | 722 | img = self.normalize(img) |
715 | 723 | return self.p_losses(img, t, *args, **kwargs) |
| 724 | + |
# trainer class

class Trainer1D(object):
    """Training harness for a ``GaussianDiffusion1D`` model.

    Wraps HuggingFace Accelerate for (optionally mixed-precision,
    multi-process) training, maintains an exponential-moving-average copy
    of the model on the main process, and periodically draws samples from
    the EMA model and checkpoints everything under ``results_folder``.
    """

    def __init__(
        self,
        diffusion_model: GaussianDiffusion1D,
        dataset: Dataset,
        *,
        train_batch_size = 16,
        gradient_accumulate_every = 1,
        train_lr = 1e-4,
        train_num_steps = 100000,
        ema_update_every = 10,
        ema_decay = 0.995,
        adam_betas = (0.9, 0.99),
        save_and_sample_every = 1000,
        num_samples = 25,
        results_folder = './results',
        amp = False,
        fp16 = False,
        split_batches = True,
    ):
        """
        Args:
            diffusion_model: the diffusion wrapper to train; its ``forward``
                must return a scalar loss.
            dataset: torch ``Dataset`` yielding training tensors.
            train_batch_size: per-step batch size (before accelerator splitting).
            gradient_accumulate_every: number of micro-batches per optimizer step.
            train_lr: Adam learning rate.
            train_num_steps: total number of optimizer steps.
            ema_update_every / ema_decay: EMA schedule, forwarded to ``EMA``.
            adam_betas: Adam ``betas``.
            save_and_sample_every: checkpoint/sampling period in steps.
            num_samples: how many sequences to sample at each milestone.
            results_folder: directory for checkpoints and samples (created
                if missing).
            amp / fp16 / split_batches: forwarded to ``Accelerator``.
        """
        super().__init__()

        # accelerator

        self.accelerator = Accelerator(
            split_batches = split_batches,
            mixed_precision = 'fp16' if fp16 else 'no'
        )

        # force native amp independently of the mixed_precision setting above
        self.accelerator.native_amp = amp

        # model

        self.model = diffusion_model
        self.channels = diffusion_model.channels

        # sampling and training hyperparameters

        # NOTE(review): the square-root constraint is inherited from the image
        # trainer (grid layout); 1-D samples are saved flat, so any count works,
        # but the assert is kept for backward compatibility.
        assert has_int_squareroot(num_samples), 'number of samples must have an integer square root'
        self.num_samples = num_samples
        self.save_and_sample_every = save_and_sample_every

        self.batch_size = train_batch_size
        self.gradient_accumulate_every = gradient_accumulate_every

        self.train_num_steps = train_num_steps

        # dataset and dataloader

        dl = DataLoader(dataset, batch_size = train_batch_size, shuffle = True, pin_memory = True, num_workers = cpu_count())

        dl = self.accelerator.prepare(dl)
        self.dl = cycle(dl)  # infinite iterator — training is step-, not epoch-, based

        # optimizer

        self.opt = Adam(diffusion_model.parameters(), lr = train_lr, betas = adam_betas)

        # EMA model and results folder live on the main process only

        if self.accelerator.is_main_process:
            self.ema = EMA(diffusion_model, beta = ema_decay, update_every = ema_update_every)
            self.ema.to(self.device)

            self.results_folder = Path(results_folder)
            # parents=True so nested result paths do not raise FileNotFoundError
            self.results_folder.mkdir(parents = True, exist_ok = True)

        # step counter state

        self.step = 0

        # prepare model, dataloader, optimizer with accelerator

        self.model, self.opt = self.accelerator.prepare(self.model, self.opt)

    @property
    def device(self):
        """Device chosen by the accelerator (cuda/cpu/...)."""
        return self.accelerator.device

    def save(self, milestone):
        """Checkpoint model/EMA/optimizer/scaler state to ``model-{milestone}.pt``.

        No-op on non-local-main processes to avoid concurrent writes.
        """
        if not self.accelerator.is_local_main_process:
            return

        data = {
            'step': self.step,
            'model': self.accelerator.get_state_dict(self.model),
            'opt': self.opt.state_dict(),
            'ema': self.ema.state_dict(),
            'scaler': self.accelerator.scaler.state_dict() if exists(self.accelerator.scaler) else None,
            'version': __version__
        }

        torch.save(data, str(self.results_folder / f'model-{milestone}.pt'))

    def load(self, milestone):
        """Restore training state from ``model-{milestone}.pt`` in ``results_folder``."""
        accelerator = self.accelerator
        device = accelerator.device

        data = torch.load(str(self.results_folder / f'model-{milestone}.pt'), map_location=device)

        # load into the unwrapped model so keys match the plain state dict
        model = self.accelerator.unwrap_model(self.model)
        model.load_state_dict(data['model'])

        self.step = data['step']
        self.opt.load_state_dict(data['opt'])
        if self.accelerator.is_main_process:
            self.ema.load_state_dict(data["ema"])

        if 'version' in data:
            print(f"loading from version {data['version']}")

        if exists(self.accelerator.scaler) and exists(data['scaler']):
            self.accelerator.scaler.load_state_dict(data['scaler'])

    def train(self):
        """Run the training loop for ``train_num_steps`` optimizer steps.

        Each step accumulates ``gradient_accumulate_every`` micro-batch
        gradients (loss is pre-divided so the effective loss is a mean),
        clips the gradient norm to 1.0, then steps the optimizer. On the
        main process the EMA model is updated every step, and every
        ``save_and_sample_every`` steps it is sampled and checkpointed.
        """
        accelerator = self.accelerator
        device = accelerator.device

        with tqdm(initial = self.step, total = self.train_num_steps, disable = not accelerator.is_main_process) as pbar:

            while self.step < self.train_num_steps:

                total_loss = 0.

                for _ in range(self.gradient_accumulate_every):
                    data = next(self.dl).to(device)

                    with self.accelerator.autocast():
                        loss = self.model(data)
                        # scale so accumulated gradients average the micro-batches
                        loss = loss / self.gradient_accumulate_every
                        total_loss += loss.item()

                    self.accelerator.backward(loss)

                accelerator.clip_grad_norm_(self.model.parameters(), 1.0)
                pbar.set_description(f'loss: {total_loss:.4f}')

                accelerator.wait_for_everyone()

                self.opt.step()
                self.opt.zero_grad()

                accelerator.wait_for_everyone()

                self.step += 1
                if accelerator.is_main_process:
                    self.ema.update()

                    if self.step != 0 and self.step % self.save_and_sample_every == 0:
                        self.ema.ema_model.eval()

                        with torch.no_grad():
                            milestone = self.step // self.save_and_sample_every
                            # split num_samples into batches no larger than batch_size
                            batches = num_to_groups(self.num_samples, self.batch_size)
                            all_samples_list = list(map(lambda n: self.ema.ema_model.sample(batch_size=n), batches))

                        all_samples = torch.cat(all_samples_list, dim = 0)

                        # BUGFIX: samples are a pickled tensor, not an image —
                        # save with a .pt extension instead of the misleading .png
                        torch.save(all_samples, str(self.results_folder / f'sample-{milestone}.pt'))
                        self.save(milestone)

                pbar.update(1)

        accelerator.print('training complete')
0 commit comments