|
18 | 18 | from packaging import version |
19 | 19 | from torchvision import transforms |
20 | 20 | from tqdm.auto import tqdm |
21 | | - |
| 21 | +import wandb |
22 | 22 | import diffusers |
23 | 23 | from diffusers import DDPMPipeline, UNet2DModel, CMStochasticIterativeScheduler, ConsistencyModelPipeline |
24 | 24 | from diffusers.optimization import get_scheduler |
|
33 | 33 |
|
34 | 34 | logger = get_logger(__name__, log_level="INFO") |
35 | 35 |
|
36 | | -def append_dims(x, target_dims): |
37 | | - """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" |
38 | | - dims_to_append = target_dims - x.ndim |
39 | | - if dims_to_append < 0: |
40 | | - raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less") |
41 | | - return x[(...,) + (None,) * dims_to_append] |
42 | | - |
43 | | - |
44 | | - |
45 | | - |
46 | | - |
47 | | - |
48 | | -def _extract_into_tensor(arr, timesteps, broadcast_shape): |
49 | | - """ |
50 | | - Extract values from a 1-D numpy array for a batch of indices. |
51 | | -
|
52 | | - :param arr: the 1-D numpy array. |
53 | | - :param timesteps: a tensor of indices into the array to extract. |
54 | | - :param broadcast_shape: a larger shape of K dimensions with the batch |
55 | | - dimension equal to the length of timesteps. |
56 | | - :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. |
57 | | - """ |
58 | | - if not isinstance(arr, torch.Tensor): |
59 | | - arr = torch.from_numpy(arr) |
60 | | - res = arr[timesteps].float().to(timesteps.device) |
61 | | - while len(res.shape) < len(broadcast_shape): |
62 | | - res = res[..., None] |
63 | | - return res.expand(broadcast_shape) |
64 | | - |
65 | 36 |
|
66 | 37 | def parse_args(): |
67 | 38 | parser = argparse.ArgumentParser(description="Simple example of a training script.") |
@@ -290,15 +261,6 @@ def main(args): |
290 | 261 | project_config=accelerator_project_config, |
291 | 262 | ) |
292 | 263 |
|
293 | | - if args.logger == "tensorboard": |
294 | | - if not is_tensorboard_available(): |
295 | | - raise ImportError("Make sure to install tensorboard if you want to use it for logging during training.") |
296 | | - |
297 | | - elif args.logger == "wandb": |
298 | | - if not is_wandb_available(): |
299 | | - raise ImportError("Make sure to install wandb if you want to use it for logging during training.") |
300 | | - import wandb |
301 | | - |
302 | 264 | # `accelerate` 0.16.0 will have better support for customized saving |
303 | 265 | if version.parse(accelerate.__version__) >= version.parse("0.16.0"): |
304 | 266 | # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format |
@@ -413,9 +375,6 @@ def load_model_hook(models, input_dir): |
413 | 375 |
|
414 | 376 | # load the model to distill into a consistency model |
415 | 377 | teacher_model = DDPMPipeline.from_pretrained("google/ddpm-cifar10-32").unet |
416 | | - model = model.float() |
417 | | - target_model = target_model.float() # TODO : support half precision training |
418 | | - teacher_model = teacher_model.float() |
419 | 378 | noise_scheduler = CMStochasticIterativeScheduler() |
420 | 379 | num_scales = 40 |
421 | 380 |
|
@@ -586,24 +545,22 @@ def transform_images(examples): |
586 | 545 | # TODO - make this cleaner |
587 | 546 | samples = noised_image |
588 | 547 | x = samples |
589 | | - model_output = teacher_model(noise_scheduler.scale_model_input(x, timestep), timestep, class_labels=labels).sample |
| 548 | + teacher_model_output = teacher_model(noise_scheduler.scale_model_input(x, timestep), timestep, class_labels=labels).sample |
590 | 549 | teacher_denoiser = noise_scheduler.step( |
591 | | - model_output, timestep, x, use_noise=False |
| 550 | + teacher_model_output, timestep, x, use_noise=False |
592 | 551 | ).prev_sample |
593 | | - d = (x - teacher_denoiser) / append_dims(sigma, x.ndim) |
594 | | - samples = x + d * append_dims(sigma_prev - sigma, x.ndim) |
595 | | - model_output = teacher_model(noise_scheduler.scale_model_input(samples, timestep_prev), timestep_prev, class_labels=labels).sample |
| 552 | + d = (x - teacher_denoiser) / sigma[(...,) + (None,) * 3] |
| 553 | + samples = x + d * (sigma_prev - sigma)[(...,) + (None,) * 3] |
| 554 | + teacher_model_output = teacher_model(noise_scheduler.scale_model_input(samples, timestep_prev), timestep_prev, class_labels=labels).sample |
596 | 555 | teacher_denoiser = noise_scheduler.step( |
597 | | - model_output, timestep_prev, samples, use_noise=False |
| 556 | + teacher_model_output, timestep_prev, samples, use_noise=False |
598 | 557 | ).prev_sample |
599 | | - |
600 | | - next_d = (samples - teacher_denoiser) / append_dims(sigma_prev, x.ndim) |
601 | | - denoised_image = x + (d + next_d) * append_dims((sigma_prev - sigma) /2, x.ndim) |
602 | | - |
| 558 | + next_d = (samples - teacher_denoiser) / sigma_prev[(...,) + (None,) * 3] |
| 559 | + denoised_image = x + (d + next_d) * ((sigma_prev - sigma) /2)[(...,) + (None,) * 3] |
603 | 560 | # get output from target model |
604 | | - model_output = target_model(noise_scheduler.scale_model_input(denoised_image, timestep_prev), timestep_prev, class_labels=labels).sample |
| 561 | + target_model_output = target_model(noise_scheduler.scale_model_input(denoised_image, timestep_prev), timestep_prev, class_labels=labels).sample |
605 | 562 | distiller_target = noise_scheduler.step( |
606 | | - model_output, timestep_prev, denoised_image, use_noise=False |
| 563 | + target_model_output, timestep_prev, denoised_image, use_noise=False |
607 | 564 | ).prev_sample |
608 | 565 |
|
609 | 566 | loss = F.mse_loss(distiller, distiller_target) |
|
0 commit comments