Skip to content

Commit 6c15636

Browse files
authored
Add training and batched inference test for DDPM vs DDIM (#140)
* Add torch_device to the VE pipeline * Mark the training test with slow
1 parent 89f2011 commit 6c15636

File tree

4 files changed

+168
-1
lines changed

4 files changed

+168
-1
lines changed

src/diffusers/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@
1111
from .optimization import (
1212
get_constant_schedule,
1313
get_constant_schedule_with_warmup,
14-
get_linear_schedule_with_warmup,
1514
get_cosine_schedule_with_warmup,
1615
get_cosine_with_hard_restarts_schedule_with_warmup,
16+
get_linear_schedule_with_warmup,
1717
get_polynomial_decay_schedule_with_warmup,
1818
get_scheduler,
1919
)

src/diffusers/training_utils.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,44 @@
11
import copy
2+
import os
3+
import random
24

5+
import numpy as np
36
import torch
47

58

9+
def enable_full_determinism(seed: int):
10+
"""
11+
Helper function for reproducible behavior during distributed training. See
12+
- https://pytorch.org/docs/stable/notes/randomness.html for pytorch
13+
"""
14+
# set seed first
15+
set_seed(seed)
16+
17+
# Enable PyTorch deterministic mode. This potentially requires either the environment
18+
# variable 'CUDA_LAUNCH_BLOCKING' or 'CUBLAS_WORKSPACE_CONFIG' to be set,
19+
# depending on the CUDA version, so we set them both here
20+
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
21+
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
22+
torch.use_deterministic_algorithms(True)
23+
24+
# Enable CUDNN deterministic mode
25+
torch.backends.cudnn.deterministic = True
26+
torch.backends.cudnn.benchmark = False
27+
28+
29+
def set_seed(seed: int):
30+
"""
31+
Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`.
32+
Args:
33+
seed (`int`): The seed to set.
34+
"""
35+
random.seed(seed)
36+
np.random.seed(seed)
37+
torch.manual_seed(seed)
38+
torch.cuda.manual_seed_all(seed)
39+
# ^^ safe to call this function even if cuda is not available
40+
41+
642
class EMAModel:
743
"""
844
Exponential Moving Average of models weights

tests/test_modeling_utils.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -876,3 +876,45 @@ def test_ldm_uncond(self):
876876
assert image.shape == (1, 256, 256, 3)
877877
expected_slice = np.array([0.4399, 0.44975, 0.46825, 0.474, 0.4359, 0.4581, 0.45095, 0.4341, 0.4447])
878878
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
879+
880+
@slow
881+
def test_ddpm_ddim_equality(self):
882+
model_id = "google/ddpm-cifar10-32"
883+
884+
unet = UNet2DModel.from_pretrained(model_id)
885+
ddpm_scheduler = DDPMScheduler(tensor_format="pt")
886+
ddim_scheduler = DDIMScheduler(tensor_format="pt")
887+
888+
ddpm = DDPMPipeline(unet=unet, scheduler=ddpm_scheduler)
889+
ddim = DDIMPipeline(unet=unet, scheduler=ddim_scheduler)
890+
891+
generator = torch.manual_seed(0)
892+
ddpm_image = ddpm(generator=generator, output_type="numpy")["sample"]
893+
894+
generator = torch.manual_seed(0)
895+
ddim_image = ddim(generator=generator, num_inference_steps=1000, eta=1.0, output_type="numpy")["sample"]
896+
897+
# the values aren't exactly equal, but the images look the same upon visual inspection
898+
assert np.abs(ddpm_image - ddim_image).max() < 1e-1
899+
900+
@slow
901+
def test_ddpm_ddim_equality_batched(self):
902+
model_id = "google/ddpm-cifar10-32"
903+
904+
unet = UNet2DModel.from_pretrained(model_id)
905+
ddpm_scheduler = DDPMScheduler(tensor_format="pt")
906+
ddim_scheduler = DDIMScheduler(tensor_format="pt")
907+
908+
ddpm = DDPMPipeline(unet=unet, scheduler=ddpm_scheduler)
909+
ddim = DDIMPipeline(unet=unet, scheduler=ddim_scheduler)
910+
911+
generator = torch.manual_seed(0)
912+
ddpm_images = ddpm(batch_size=2, generator=generator, output_type="numpy")["sample"]
913+
914+
generator = torch.manual_seed(0)
915+
ddim_images = ddim(batch_size=2, generator=generator, num_inference_steps=1000, eta=1.0, output_type="numpy")[
916+
"sample"
917+
]
918+
919+
# the values aren't exactly equal, but the images look the same upon visual inspection
920+
assert np.abs(ddpm_images - ddim_images).max() < 1e-1

tests/test_training.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
# coding=utf-8
2+
# Copyright 2022 HuggingFace Inc.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import unittest
17+
18+
import torch
19+
20+
from diffusers import DDIMScheduler, DDPMScheduler, UNet2DModel
21+
from diffusers.testing_utils import slow, torch_device
22+
from diffusers.training_utils import enable_full_determinism, set_seed
23+
24+
25+
torch.backends.cuda.matmul.allow_tf32 = False
26+
27+
28+
class TrainingTests(unittest.TestCase):
29+
def get_model_optimizer(self, resolution=32):
30+
set_seed(0)
31+
model = UNet2DModel(sample_size=resolution, in_channels=3, out_channels=3)
32+
optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)
33+
return model, optimizer
34+
35+
@slow
36+
def test_training_step_equality(self):
37+
enable_full_determinism(0)
38+
39+
ddpm_scheduler = DDPMScheduler(
40+
num_train_timesteps=1000,
41+
beta_start=0.0001,
42+
beta_end=0.02,
43+
beta_schedule="linear",
44+
clip_sample=True,
45+
tensor_format="pt",
46+
)
47+
ddim_scheduler = DDIMScheduler(
48+
num_train_timesteps=1000,
49+
beta_start=0.0001,
50+
beta_end=0.02,
51+
beta_schedule="linear",
52+
clip_sample=True,
53+
tensor_format="pt",
54+
)
55+
56+
assert ddpm_scheduler.num_train_timesteps == ddim_scheduler.num_train_timesteps
57+
58+
# shared batches for DDPM and DDIM
59+
set_seed(0)
60+
clean_images = [torch.randn((4, 3, 32, 32)).clip(-1, 1).to(torch_device) for _ in range(4)]
61+
noise = [torch.randn((4, 3, 32, 32)).to(torch_device) for _ in range(4)]
62+
timesteps = [torch.randint(0, 1000, (4,)).long().to(torch_device) for _ in range(4)]
63+
64+
# train with a DDPM scheduler
65+
model, optimizer = self.get_model_optimizer(resolution=32)
66+
model.train().to(torch_device)
67+
for i in range(4):
68+
optimizer.zero_grad()
69+
ddpm_noisy_images = ddpm_scheduler.add_noise(clean_images[i], noise[i], timesteps[i])
70+
ddpm_noise_pred = model(ddpm_noisy_images, timesteps[i])["sample"]
71+
loss = torch.nn.functional.mse_loss(ddpm_noise_pred, noise[i])
72+
loss.backward()
73+
optimizer.step()
74+
del model, optimizer
75+
76+
# recreate the model and optimizer, and retry with DDIM
77+
model, optimizer = self.get_model_optimizer(resolution=32)
78+
model.train().to(torch_device)
79+
for i in range(4):
80+
optimizer.zero_grad()
81+
ddim_noisy_images = ddim_scheduler.add_noise(clean_images[i], noise[i], timesteps[i])
82+
ddim_noise_pred = model(ddim_noisy_images, timesteps[i])["sample"]
83+
loss = torch.nn.functional.mse_loss(ddim_noise_pred, noise[i])
84+
loss.backward()
85+
optimizer.step()
86+
del model, optimizer
87+
88+
self.assertTrue(torch.allclose(ddpm_noisy_images, ddim_noisy_images, atol=1e-5))
89+
self.assertTrue(torch.allclose(ddpm_noise_pred, ddim_noise_pred, atol=1e-5))

0 commit comments

Comments
 (0)