Commit 8742e4e

Add review suggestions

1 parent a32b869 commit 8742e4e

3 files changed: +13 -57 lines changed
Lines changed: 2 additions & 0 deletions

@@ -1,3 +1,5 @@
 accelerate>=0.16.0
 torchvision
 datasets
+wandb
+tensorboard

examples/consistency_models/script.sh

Lines changed: 0 additions & 3 deletions
This file was deleted.

examples/consistency_models/train_consistency_distillation.py

Lines changed: 11 additions & 54 deletions
@@ -18,7 +18,7 @@
 from packaging import version
 from torchvision import transforms
 from tqdm.auto import tqdm
-
+import wandb
 import diffusers
 from diffusers import DDPMPipeline, UNet2DModel, CMStochasticIterativeScheduler, ConsistencyModelPipeline
 from diffusers.optimization import get_scheduler
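
(With wandb pinned in the requirements file above, the import can sit at module level; the conditional `import wandb` inside `main()` is removed in a later hunk below.)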
@@ -33,35 +33,6 @@
 
 logger = get_logger(__name__, log_level="INFO")
 
-def append_dims(x, target_dims):
-    """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
-    dims_to_append = target_dims - x.ndim
-    if dims_to_append < 0:
-        raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
-    return x[(...,) + (None,) * dims_to_append]
-
-
-def _extract_into_tensor(arr, timesteps, broadcast_shape):
-    """
-    Extract values from a 1-D numpy array for a batch of indices.
-
-    :param arr: the 1-D numpy array.
-    :param timesteps: a tensor of indices into the array to extract.
-    :param broadcast_shape: a larger shape of K dimensions with the batch
-        dimension equal to the length of timesteps.
-    :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
-    """
-    if not isinstance(arr, torch.Tensor):
-        arr = torch.from_numpy(arr)
-    res = arr[timesteps].float().to(timesteps.device)
-    while len(res.shape) < len(broadcast_shape):
-        res = res[..., None]
-    return res.expand(broadcast_shape)
-
 
 def parse_args():
     parser = argparse.ArgumentParser(description="Simple example of a training script.")
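
The deleted `append_dims` helper is replaced at its call sites by inline `[(...,) + (None,) * 3]` indexing, which assumes 4-D image batches (batch, channel, height, width) and a 1-D per-sample sigma. A minimal sketch checking that equivalence (the shapes here are illustrative assumptions; `append_dims` is copied from the removed code):

    import torch

    def append_dims(x, target_dims):
        # same trailing-None trick as the deleted helper
        return x[(...,) + (None,) * (target_dims - x.ndim)]

    x = torch.randn(8, 3, 32, 32)  # (batch, channel, height, width)
    sigma = torch.rand(8)          # one sigma per batch element

    old_form = x / append_dims(sigma, x.ndim)   # helper version
    new_form = x / sigma[(...,) + (None,) * 3]  # inlined version: (8,) -> (8, 1, 1, 1)
    assert torch.equal(old_form, new_form)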
@@ -290,15 +261,6 @@ def main(args):
         project_config=accelerator_project_config,
     )
 
-    if args.logger == "tensorboard":
-        if not is_tensorboard_available():
-            raise ImportError("Make sure to install tensorboard if you want to use it for logging during training.")
-
-    elif args.logger == "wandb":
-        if not is_wandb_available():
-            raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
-        import wandb
-
     # `accelerate` 0.16.0 will have better support for customized saving
     if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
         # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
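
With the availability guards removed, wandb and tensorboard become hard dependencies installed from the requirements file, and tracker setup can be left to accelerate. A hedged sketch of that pattern, not the script's verbatim code (the run name "consistency_distillation" and the config dict are made up here; `log_with`, `init_trackers`, `log`, and `end_training` are standard accelerate API):

    from accelerate import Accelerator

    accelerator = Accelerator(log_with="wandb")  # or "tensorboard"
    accelerator.init_trackers("consistency_distillation", config={"lr": 1e-4})
    accelerator.log({"train/loss": 0.123}, step=0)
    accelerator.end_training()  # flushes and closes the trackers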
@@ -413,9 +375,6 @@ def load_model_hook(models, input_dir):
 
     # load the model to distill into a consistency model
     teacher_model = DDPMPipeline.from_pretrained("google/ddpm-cifar10-32").unet
-    model = model.float()
-    target_model = target_model.float()  # TODO: support half precision training
-    teacher_model = teacher_model.float()
     noise_scheduler = CMStochasticIterativeScheduler()
     num_scales = 40
 
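Dropping the blanket `.float()` casts (and the half-precision TODO) leaves precision to be handled once, upstream; in accelerate-based training scripts that is usually the Accelerator's `mixed_precision` setting. A minimal sketch under that assumption (the linear layer is a stand-in for the script's UNet, not its actual model):

    import torch
    from accelerate import Accelerator

    accelerator = Accelerator(mixed_precision="bf16")  # or "fp16" on GPU, "no" for fp32
    model = torch.nn.Linear(16, 16)                    # stand-in for the UNet
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    # prepare() keeps master weights in full precision and autocasts forward
    # passes, so per-model .float() casts become unnecessary
    model, optimizer = accelerator.prepare(model, optimizer)
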
@@ -586,24 +545,22 @@ def transform_images(examples):
                 # TODO - make this cleaner
                 samples = noised_image
                 x = samples
-                model_output = teacher_model(noise_scheduler.scale_model_input(x, timestep), timestep, class_labels=labels).sample
+                teacher_model_output = teacher_model(noise_scheduler.scale_model_input(x, timestep), timestep, class_labels=labels).sample
                 teacher_denoiser = noise_scheduler.step(
-                    model_output, timestep, x, use_noise=False
+                    teacher_model_output, timestep, x, use_noise=False
                 ).prev_sample
-                d = (x - teacher_denoiser) / append_dims(sigma, x.ndim)
-                samples = x + d * append_dims(sigma_prev - sigma, x.ndim)
-                model_output = teacher_model(noise_scheduler.scale_model_input(samples, timestep_prev), timestep_prev, class_labels=labels).sample
+                d = (x - teacher_denoiser) / sigma[(...,) + (None,) * 3]
+                samples = x + d * (sigma_prev - sigma)[(...,) + (None,) * 3]
+                teacher_model_output = teacher_model(noise_scheduler.scale_model_input(samples, timestep_prev), timestep_prev, class_labels=labels).sample
                 teacher_denoiser = noise_scheduler.step(
-                    model_output, timestep_prev, samples, use_noise=False
+                    teacher_model_output, timestep_prev, samples, use_noise=False
                 ).prev_sample
-
-                next_d = (samples - teacher_denoiser) / append_dims(sigma_prev, x.ndim)
-                denoised_image = x + (d + next_d) * append_dims((sigma_prev - sigma) / 2, x.ndim)
-
+                next_d = (samples - teacher_denoiser) / sigma_prev[(...,) + (None,) * 3]
+                denoised_image = x + (d + next_d) * ((sigma_prev - sigma) / 2)[(...,) + (None,) * 3]
                 # get output from target model
-                model_output = target_model(noise_scheduler.scale_model_input(denoised_image, timestep_prev), timestep_prev, class_labels=labels).sample
+                target_model_output = target_model(noise_scheduler.scale_model_input(denoised_image, timestep_prev), timestep_prev, class_labels=labels).sample
                 distiller_target = noise_scheduler.step(
-                    model_output, timestep_prev, denoised_image, use_noise=False
+                    target_model_output, timestep_prev, denoised_image, use_noise=False
                 ).prev_sample
 
                 loss = F.mse_loss(distiller, distiller_target)
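
The two teacher evaluations above form one Heun (second-order) step of the probability-flow ODE from sigma to sigma_prev, the teacher update used in consistency distillation: an Euler step along slope d, then a correction using the average of d and next_d. A pure-tensor restatement of the arithmetic (sketch only; `denoise` stands in for the teacher model plus `noise_scheduler.step` pair, and the dummy shapes below are assumptions):

    import torch

    def heun_step(x, sigma, sigma_prev, denoise):
        # Euler predictor: ODE slope at (x, sigma)
        d = (x - denoise(x, sigma)) / sigma[(...,) + (None,) * 3]
        x_euler = x + d * (sigma_prev - sigma)[(...,) + (None,) * 3]
        # Heun corrector: average the slopes at both ends of the step
        next_d = (x_euler - denoise(x_euler, sigma_prev)) / sigma_prev[(...,) + (None,) * 3]
        return x + (d + next_d) * ((sigma_prev - sigma) / 2)[(...,) + (None,) * 3]

    x = torch.randn(8, 3, 32, 32)
    sigma = torch.full((8,), 2.0)
    sigma_prev = torch.full((8,), 1.5)
    out = heun_step(x, sigma, sigma_prev, lambda s, sig: 0.9 * s)  # dummy denoiser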
