Skip to content

Commit 0a18325

Browse files
committed
fix some issues with ddpm 1d
1 parent ea5713e commit 0a18325

File tree

4 files changed

+25
-13
lines changed

4 files changed

+25
-13
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ diffusion = GaussianDiffusion(
4545
training_images = torch.rand(8, 3, 128, 128) # images are normalized from 0 to 1
4646
loss = diffusion(training_images)
4747
loss.backward()
48+
4849
# after a lot of training
4950

5051
sampled_images = diffusion.sample(batch_size = 4)

denoising_diffusion_pytorch/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,4 @@
66
from denoising_diffusion_pytorch.elucidated_diffusion import ElucidatedDiffusion
77
from denoising_diffusion_pytorch.v_param_continuous_time_gaussian_diffusion import VParamContinuousTimeGaussianDiffusion
88

9-
from denoising_diffusion_pytorch.denoising_diffusion_pytorch_1d import GaussianDiffusion1D, Unet1D, Trainer1D
10-
9+
from denoising_diffusion_pytorch.denoising_diffusion_pytorch_1d import GaussianDiffusion1D, Unet1D, Trainer1D, Dataset1D

denoising_diffusion_pytorch/denoising_diffusion_pytorch_1d.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,22 @@
11
import math
2-
from multiprocessing import cpu_count
32
from pathlib import Path
43
from random import random
54
from functools import partial
65
from collections import namedtuple
6+
from multiprocessing import cpu_count
77

88
import torch
9-
from accelerate import Accelerator
10-
from ema_pytorch import EMA
11-
from torch import nn, einsum
9+
from torch import nn, einsum, Tensor
1210
import torch.nn.functional as F
1311
from torch.cuda.amp import autocast
12+
from torch.optim import Adam
13+
from torch.utils.data import Dataset, DataLoader
1414

1515
from einops import rearrange, reduce
1616
from einops.layers.torch import Rearrange
17-
from torch.optim import Adam
18-
from torch.utils.data import Dataset, DataLoader
17+
18+
from accelerate import Accelerator
19+
from ema_pytorch import EMA
1920

2021
from tqdm.auto import tqdm
2122

@@ -67,6 +68,19 @@ def normalize_to_neg_one_to_one(img):
6768
def unnormalize_to_zero_to_one(t):
6869
return (t + 1) * 0.5
6970

71+
# data
72+
73+
class Dataset1D(Dataset):
74+
def __init__(self, tensor: Tensor):
75+
super().__init__()
76+
self.tensor = tensor.clone()
77+
78+
def __len__(self):
79+
return len(self.tensor)
80+
81+
def __getitem__(self, idx):
82+
return self.tensor[idx].clone()
83+
7084
# small helper modules
7185

7286
class Residual(nn.Module):
@@ -714,7 +728,7 @@ def __init__(
714728
num_samples = 25,
715729
results_folder = './results',
716730
amp = False,
717-
fp16 = False,
731+
mixed_precision_type = 'fp16',
718732
split_batches = True,
719733
):
720734
super().__init__()
@@ -723,11 +737,9 @@ def __init__(
723737

724738
self.accelerator = Accelerator(
725739
split_batches = split_batches,
726-
mixed_precision = 'fp16' if fp16 else 'no'
740+
mixed_precision = mixed_precision_type if amp else 'no'
727741
)
728742

729-
self.accelerator.native_amp = amp
730-
731743
# model
732744

733745
self.model = diffusion_model
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '1.8.2'
1+
__version__ = '1.8.3'

0 commit comments

Comments
 (0)