
Commit aa201af

Merge pull request #204 from klasocki/main
Add Trainer1D for 1D diffusion
2 parents e3c8036 + 1aa21dd commit aa201af

File tree

3 files changed: +197 −6 lines changed


README.md

Lines changed: 21 additions & 5 deletions

@@ -106,11 +106,10 @@ $ accelerate launch train.py
 
 ### 1D Sequence
 
-By popular request, a 1D Unet + Gaussian Diffusion implementation. You will have to do the training code yourself
-
+By popular request, a 1D Unet + Gaussian Diffusion implementation.
 ```python
 import torch
-from denoising_diffusion_pytorch import Unet1D, GaussianDiffusion1D
+from denoising_diffusion_pytorch import Unet1D, GaussianDiffusion1D, Trainer1D
 
 model = Unet1D(
     dim = 64,
@@ -125,16 +124,33 @@ diffusion = GaussianDiffusion1D(
     objective = 'pred_v'
 )
 
-training_seq = torch.rand(8, 32, 128) # features are normalized from 0 to 1
+training_seq = torch.rand(64, 32, 128) # features are normalized from 0 to 1
 loss = diffusion(training_seq)
 loss.backward()
 
+# Or using trainer
+
+trainer = Trainer1D(
+    diffusion,
+    dataset = training_seq,
+    train_batch_size = 32,
+    train_lr = 8e-5,
+    train_num_steps = 700000,         # total training steps
+    gradient_accumulate_every = 2,    # gradient accumulation steps
+    ema_decay = 0.995,                # exponential moving average decay
+    amp = True,                       # turn on mixed precision
+)
+trainer.train()
+
 # after a lot of training
 
 sampled_seq = diffusion.sample(batch_size = 4)
 sampled_seq.shape # (4, 32, 128)
-```
 
+```
+`Trainer1D` does not evaluate the generated samples in any way since the type of data is not known.
+You could consider adding a suitable metric to the training loop yourself after doing an editable install of this package
+`pip install -e .`.
 ## Citations
 
 ```bibtex
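
Since the note added to the README above leaves evaluation to the user, here is a minimal sketch (not part of this commit) of scoring generated sequences after training. It assumes the `diffusion`, `trainer`, and `training_seq` objects from the README snippet, and `nearest_neighbour_mse` is a hypothetical placeholder metric to be replaced with whatever suits your data:

```python
import torch

# after trainer.train(), draw samples from the EMA copy the trainer maintains
with torch.no_grad():
    generated = trainer.ema.ema_model.sample(batch_size = 16)   # (16, 32, 128)

def nearest_neighbour_mse(generated, reference):
    # placeholder metric: mean squared distance from each generated sequence
    # to its closest training sequence
    dists = ((generated.unsqueeze(1) - reference.unsqueeze(0)) ** 2).mean(dim = (-1, -2))
    return dists.min(dim = 1).values.mean()

print(nearest_neighbour_mse(generated.cpu(), training_seq))
```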

denoising_diffusion_pytorch/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -6,5 +6,5 @@
 from denoising_diffusion_pytorch.elucidated_diffusion import ElucidatedDiffusion
 from denoising_diffusion_pytorch.v_param_continuous_time_gaussian_diffusion import VParamContinuousTimeGaussianDiffusion
 
-from denoising_diffusion_pytorch.denoising_diffusion_pytorch_1d import GaussianDiffusion1D, Unet1D
+from denoising_diffusion_pytorch.denoising_diffusion_pytorch_1d import GaussianDiffusion1D, Unet1D, Trainer1D
 
denoising_diffusion_pytorch/denoising_diffusion_pytorch_1d.py

Lines changed: 175 additions & 0 deletions

@@ -1,17 +1,25 @@
 import math
+from multiprocessing import cpu_count
+from pathlib import Path
 from random import random
 from functools import partial
 from collections import namedtuple
 
 import torch
+from accelerate import Accelerator
+from ema_pytorch import EMA
 from torch import nn, einsum
 import torch.nn.functional as F
 
 from einops import rearrange, reduce
 from einops.layers.torch import Rearrange
+from torch.optim import Adam
+from torch.utils.data import Dataset, DataLoader
 
 from tqdm.auto import tqdm
 
+from denoising_diffusion_pytorch.version import __version__
+
 # constants
 
 ModelPrediction = namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start'])
@@ -713,3 +721,170 @@ def forward(self, img, *args, **kwargs):
 
         img = self.normalize(img)
         return self.p_losses(img, t, *args, **kwargs)
+
+# trainer class
+
+class Trainer1D(object):
+    def __init__(
+        self,
+        diffusion_model: GaussianDiffusion1D,
+        dataset: Dataset,
+        *,
+        train_batch_size = 16,
+        gradient_accumulate_every = 1,
+        train_lr = 1e-4,
+        train_num_steps = 100000,
+        ema_update_every = 10,
+        ema_decay = 0.995,
+        adam_betas = (0.9, 0.99),
+        save_and_sample_every = 1000,
+        num_samples = 25,
+        results_folder = './results',
+        amp = False,
+        fp16 = False,
+        split_batches = True,
+    ):
+        super().__init__()
+
+        # accelerator
+
+        self.accelerator = Accelerator(
+            split_batches = split_batches,
+            mixed_precision = 'fp16' if fp16 else 'no'
+        )
+
+        self.accelerator.native_amp = amp
+
+        # model
+
+        self.model = diffusion_model
+        self.channels = diffusion_model.channels
+
+        # sampling and training hyperparameters
+
+        assert has_int_squareroot(num_samples), 'number of samples must have an integer square root'
+        self.num_samples = num_samples
+        self.save_and_sample_every = save_and_sample_every
+
+        self.batch_size = train_batch_size
+        self.gradient_accumulate_every = gradient_accumulate_every
+
+        self.train_num_steps = train_num_steps
+
+        # dataset and dataloader
+
+        dl = DataLoader(dataset, batch_size = train_batch_size, shuffle = True, pin_memory = True, num_workers = cpu_count())
+
+        dl = self.accelerator.prepare(dl)
+        self.dl = cycle(dl)
+
+        # optimizer
+
+        self.opt = Adam(diffusion_model.parameters(), lr = train_lr, betas = adam_betas)
+
+        # for logging results in a folder periodically
+
+        if self.accelerator.is_main_process:
+            self.ema = EMA(diffusion_model, beta = ema_decay, update_every = ema_update_every)
+            self.ema.to(self.device)
+
+        self.results_folder = Path(results_folder)
+        self.results_folder.mkdir(exist_ok = True)
+
+        # step counter state
+
+        self.step = 0
+
+        # prepare model, dataloader, optimizer with accelerator
+
+        self.model, self.opt = self.accelerator.prepare(self.model, self.opt)
+
+    @property
+    def device(self):
+        return self.accelerator.device
+
+    def save(self, milestone):
+        if not self.accelerator.is_local_main_process:
+            return
+
+        data = {
+            'step': self.step,
+            'model': self.accelerator.get_state_dict(self.model),
+            'opt': self.opt.state_dict(),
+            'ema': self.ema.state_dict(),
+            'scaler': self.accelerator.scaler.state_dict() if exists(self.accelerator.scaler) else None,
+            'version': __version__
+        }
+
+        torch.save(data, str(self.results_folder / f'model-{milestone}.pt'))
+
+    def load(self, milestone):
+        accelerator = self.accelerator
+        device = accelerator.device
+
+        data = torch.load(str(self.results_folder / f'model-{milestone}.pt'), map_location=device)
+
+        model = self.accelerator.unwrap_model(self.model)
+        model.load_state_dict(data['model'])
+
+        self.step = data['step']
+        self.opt.load_state_dict(data['opt'])
+        if self.accelerator.is_main_process:
+            self.ema.load_state_dict(data["ema"])
+
+        if 'version' in data:
+            print(f"loading from version {data['version']}")
+
+        if exists(self.accelerator.scaler) and exists(data['scaler']):
+            self.accelerator.scaler.load_state_dict(data['scaler'])
+
+    def train(self):
+        accelerator = self.accelerator
+        device = accelerator.device
+
+        with tqdm(initial = self.step, total = self.train_num_steps, disable = not accelerator.is_main_process) as pbar:
+
+            while self.step < self.train_num_steps:
+
+                total_loss = 0.
+
+                for _ in range(self.gradient_accumulate_every):
+                    data = next(self.dl).to(device)
+
+                    with self.accelerator.autocast():
+                        loss = self.model(data)
+                        loss = loss / self.gradient_accumulate_every
+                        total_loss += loss.item()
+
+                    self.accelerator.backward(loss)
+
+                accelerator.clip_grad_norm_(self.model.parameters(), 1.0)
+                pbar.set_description(f'loss: {total_loss:.4f}')
+
+                accelerator.wait_for_everyone()
+
+                self.opt.step()
+                self.opt.zero_grad()
+
+                accelerator.wait_for_everyone()
+
+                self.step += 1
+                if accelerator.is_main_process:
+                    self.ema.update()
+
+                    if self.step != 0 and self.step % self.save_and_sample_every == 0:
+                        self.ema.ema_model.eval()
+
+                        with torch.no_grad():
+                            milestone = self.step // self.save_and_sample_every
+                            batches = num_to_groups(self.num_samples, self.batch_size)
+                            all_samples_list = list(map(lambda n: self.ema.ema_model.sample(batch_size=n), batches))
+                        #
+                        all_samples = torch.cat(all_samples_list, dim = 0)
+                        #
+                        torch.save(all_samples, str(self.results_folder / f'sample-{milestone}.png'))
+                        self.save(milestone)
+
+                pbar.update(1)
+
+        accelerator.print('training complete')
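
As a usage note (a sketch, not part of the diff): `save` writes `model-{milestone}.pt` checkpoints into `results_folder` every `save_and_sample_every` steps, `load` restores the model, EMA, optimizer, and step counter by milestone number, and samples can then be drawn from the EMA copy. The `diffusion` and `training_seq` objects are assumed to be set up as in the README snippet above:

```python
from denoising_diffusion_pytorch import Trainer1D

trainer = Trainer1D(
    diffusion,
    dataset = training_seq,
    train_num_steps = 2000,
    save_and_sample_every = 1000,
    results_folder = './results'
)

trainer.train()    # writes ./results/model-1.pt, model-2.pt and sample-*.png
                   # (the sample-* files are raw tensors saved with torch.save, despite the .png extension)

trainer.load(2)    # restore milestone 2: model, EMA, optimizer and step counter
sampled = trainer.ema.ema_model.sample(batch_size = 4)    # sample from the EMA weights
```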
