Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion improved_diffusion/dist_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,15 @@ def load_state_dict(path, **kwargs):
def sync_params(params):
    """
    Synchronize a sequence of Tensors across ranks from rank 0.

    NOTE(review): synchronization is currently a no-op — the original
    dist.broadcast loop was disabled, presumably because the DeepSpeed/MPI
    launcher performs parameter sync itself (confirm against the launch
    script). Returns 0 so any caller inspecting the result keeps working.

    :param params: a sequence of Tensors (unused while sync is disabled).
    :return: 0, unconditionally.
    """
    # Fix: the old broadcast loop had been left behind as an unreachable
    # triple-quoted string expression after `return 0`; it is removed here.
    # Restore the following if rank-0 broadcast is ever needed again:
    #     for p in params:
    #         with th.no_grad():
    #             dist.broadcast(p, 0)
    return 0

def _find_free_port():
try:
Expand Down
42 changes: 41 additions & 1 deletion improved_diffusion/image_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,48 @@
import numpy as np
from torch.utils.data import DataLoader, Dataset

def load_dataset(
    *, data_dir, batch_size, image_size, class_cond=False, deterministic=True
):
    """
    Build an ImageDataset over the image files found under *data_dir*.

    Unlike load_data, this returns the Dataset object itself rather than a
    generator over (images, kwargs) batches; the caller is responsible for
    wrapping it in a DataLoader. (The previous docstring was copied from
    load_data and described a generator — fixed here to match the code.)

    :param data_dir: a dataset directory.
    :param batch_size: unused here; kept for signature parity with load_data.
    :param image_size: the size to which images are resized.
    :param class_cond: if True, derive an integer class label for each image
        from the filename prefix before the first underscore, exposed by the
        dataset under the "y" key.
    :param deterministic: unused here; kept for signature parity with
        load_data.
    :raises ValueError: if data_dir is empty or None.
    :return: an ImageDataset sharded across MPI ranks.
    """
    if not data_dir:
        raise ValueError("unspecified data directory")
    all_files = _list_image_files_recursively(data_dir)
    classes = None
    if class_cond:
        # Assume classes are the first part of the filename,
        # before an underscore.
        class_names = [bf.basename(path).split("_")[0] for path in all_files]
        sorted_classes = {x: i for i, x in enumerate(sorted(set(class_names)))}
        classes = [sorted_classes[x] for x in class_names]
    dataset = ImageDataset(
        image_size,
        all_files,
        classes=classes,
        # Shard the file list so each MPI rank sees a disjoint subset.
        shard=MPI.COMM_WORLD.Get_rank(),
        num_shards=MPI.COMM_WORLD.Get_size(),
    )

    return dataset


def load_data(
*, data_dir, batch_size, image_size, class_cond=False, deterministic=False
*, data_dir, batch_size, image_size, class_cond=False, deterministic=True
):
"""
For a dataset, create a generator over (images, kwargs) pairs.
Expand All @@ -24,6 +63,7 @@ def load_data(
exception will be raised.
:param deterministic: if True, yield results in a deterministic order.
"""

if not data_dir:
raise ValueError("unspecified data directory")
all_files = _list_image_files_recursively(data_dir)
Expand Down
97 changes: 80 additions & 17 deletions improved_diffusion/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@
from .nn import update_ema
from .resample import LossAwareSampler, UniformSampler

import pickle
import matplotlib.pyplot as plt

from mpi4py import MPI

# For ImageNet experiments, this was a good default value.
# We found that the lg_loss_scale quickly climbed to
# 20-21 within the first ~1K steps of training.
Expand All @@ -30,6 +35,7 @@ class TrainLoop:
def __init__(
self,
*,
opt,
model,
diffusion,
data,
Expand All @@ -48,7 +54,7 @@ def __init__(
):
self.model = model
self.diffusion = diffusion
self.data = data
self.data = data
self.batch_size = batch_size
self.microbatch = microbatch if microbatch > 0 else batch_size
self.lr = lr
Expand Down Expand Up @@ -78,8 +84,10 @@ def __init__(
self._load_and_sync_parameters()
if self.use_fp16:
self._setup_fp16()

# self.opt = AdamW(self.master_params, lr=self.lr, weight_decay=self.weight_decay)
self.opt = opt

self.opt = AdamW(self.master_params, lr=self.lr, weight_decay=self.weight_decay)
if self.resume_step:
self._load_optimizer_state()
# Model was resumed, either due to a restart or a checkpoint
Expand All @@ -92,6 +100,7 @@ def __init__(
copy.deepcopy(self.master_params) for _ in range(len(self.ema_rate))
]

"""
if th.cuda.is_available():
self.use_ddp = True
self.ddp_model = DDP(
Expand All @@ -110,6 +119,9 @@ def __init__(
)
self.use_ddp = False
self.ddp_model = self.model
"""
self.use_ddp=False
self.ddp_model = self.model

def _load_and_sync_parameters(self):
resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
Expand Down Expand Up @@ -159,33 +171,66 @@ def _setup_fp16(self):
self.model.convert_to_fp16()

def run_loop(self):
losses, iter_times = [], []

start_event = th.cuda.Event(enable_timing=True)
stop_event = th.cuda.Event(enable_timing=True)

while (
not self.lr_anneal_steps
or self.step + self.resume_step < self.lr_anneal_steps
):
batch, cond = next(self.data)
self.run_step(batch, cond)

start_event.record()
self.run_step(batch, cond, losses)
stop_event.record()
th.cuda.synchronize()
iter_times.append(start_event.elapsed_time(stop_event))

if self.step % self.log_interval == 0:
logger.dumpkvs()

# Disabled checkpointing for now as it was causing a pickling error
# doesn't deepspeed support checkpointing too?

"""
if self.step % self.save_interval == 0:
self.save()
# Run for a finite amount of time in integration tests.
if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0:
return
"""
self.step += 1
# Save the last checkpoint if it wasn't already saved.
if (self.step - 1) % self.save_interval != 0:
"""
# Save the last checkpoint if it wasn't already saved.
if (self.step - 1) % self.save_interval != 0:
self.save()

def run_step(self, batch, cond):
self.forward_backward(batch, cond)
"""

if self.step % 1 == 0:
with open('iter_deepspeed' + str(MPI.COMM_WORLD.size) + '.pickle', 'wb') as handle:
pickle.dump(iter_times, handle, protocol=pickle.HIGHEST_PROTOCOL)

with open('validation_deepspeed' + str(MPI.COMM_WORLD.size) + '.pickle', 'wb') as handle:
pickle.dump(losses, handle, protocol=pickle.HIGHEST_PROTOCOL)

plt.plot([i for i in range(len(losses))], losses)
plt.xlabel("Step")
plt.ylabel("Loss")
plt.savefig('out' + str(MPI.COMM_WORLD.size) + '.png')

def run_step(self, batch, cond, losses):
    """
    Execute one training step: forward/backward, optimizer step, logging.

    :param batch: the input batch tensor.
    :param cond: the conditioning kwargs for the batch.
    :param losses: caller-owned list that forward_backward appends
        per-step loss values into.
    """
    self.forward_backward(batch, cond, losses)
    if not self.use_fp16:
        # DeepSpeed path: optimize_normal() was modified to call
        # model_engine.step().
        self.optimize_normal()
    else:
        # fp16 is not used right now; presumably DeepSpeed would handle
        # loss scaling if it were — TODO confirm.
        self.optimize_fp16()
    self.log_step()

def forward_backward(self, batch, cond):
def forward_backward(self, batch, cond, loss_list):
zero_grad(self.model_params)
for i in range(0, batch.shape[0], self.microbatch):
micro = batch[i : i + self.microbatch].to(dist_util.dev())
Expand All @@ -196,19 +241,22 @@ def forward_backward(self, batch, cond):
last_batch = (i + self.microbatch) >= batch.shape[0]
t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev())


# looks like training_losses method of GaussianDiffusion class takes
# care of both the forward pass and calculating the loss
compute_losses = functools.partial(
self.diffusion.training_losses,
self.ddp_model,
micro,
t,
model_kwargs=micro_cond,
)

if last_batch or not self.use_ddp:
losses = compute_losses()
else:
with self.ddp_model.no_sync():
with th.autocast(device_type="cuda", dtype=th.bfloat16):
if last_batch or not self.use_ddp:
losses = compute_losses()
else:
with self.ddp_model.no_sync():
losses = compute_losses()

if isinstance(self.schedule_sampler, LossAwareSampler):
self.schedule_sampler.update_with_local_losses(
Expand All @@ -219,11 +267,20 @@ def forward_backward(self, batch, cond):
log_loss_dict(
self.diffusion, t, {k: v * weights for k, v in losses.items()}
)

loss_list.append(loss.item())

"""
if self.use_fp16:
loss_scale = 2 ** self.lg_loss_scale
(loss * loss_scale).backward()
else:
loss.backward()
"""

# this is just the model_engine
# this is a wrapper over loss.backward and should take care of scaling, etc as well if needed
self.ddp_model.backward(loss)

def optimize_fp16(self):
if any(not th.isfinite(p.grad).all() for p in self.model_params):
Expand All @@ -242,11 +299,17 @@ def optimize_fp16(self):
self.lg_loss_scale += self.fp16_scale_growth

def optimize_normal(self):
    """
    Anneal the learning rate, then step the DeepSpeed model engine.
    """
    # Gradient-norm logging is disabled: .grad attributes are not
    # accessible under a ZeRO stage-3 partitioned model.
    # self._log_grad_norm()
    self._anneal_lr()
    # self.ddp_model is the DeepSpeed model engine, which wraps the
    # optimizer; engine.step() replaces the former self.opt.step().
    self.ddp_model.step()
    # EMA updates are disabled for now because they were erroring:
    # for rate, params in zip(self.ema_rate, self.ema_params):
    #     update_ema(params, self.master_params, rate=rate)

def _log_grad_norm(self):
sqsum = 0.0
Expand Down
Loading