
Commit 2ada094

Authored by drhead
Add extra performance features for EMAModel, torch._foreach operations and better support for non-blocking CPU offloading (huggingface#7685)
* Add support for _foreach operations and non-blocking to EMAModel
* default foreach to false
* add non-blocking EMA offloading to SD1.5 T2I example script
* fix whitespace
* move foreach to cli argument
* linting
* Update README.md re: EMA weight training
* correct args.foreach_ema
* add tests for foreach ema
* code quality
* add foreach to from_pretrained
* default foreach false
* fix linting

---------

Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: drhead <[email protected]>
1 parent f1f542b commit 2ada094

File tree: 4 files changed, +226 -18 lines

examples/text_to_image/README.md

Lines changed: 8 additions & 0 deletions
@@ -170,11 +170,19 @@ For our small Narutos dataset, the effects of Min-SNR weighting strategy might n
 
 Also, note that in this example, we either predict `epsilon` (i.e., the noise) or the `v_prediction`. For both of these cases, the formulation of the Min-SNR weighting strategy that we have used holds.
 
+
+#### Training with EMA weights
+
+Through the `EMAModel` class, we support a convenient method of tracking an exponential moving average of model parameters. This helps to smooth out noise in model parameter updates and generally improves model performance. If enabled with the `--use_ema` argument, the final model checkpoint saved at the end of training will use the EMA weights.
+
+EMA weights require an additional full-precision copy of the model parameters to be stored in memory, but otherwise add very little performance overhead. `--foreach_ema` can be used to reduce that overhead further. If you are short on VRAM and still want to use EMA weights, you can store them in CPU RAM by passing the `--offload_ema` argument. This keeps the EMA weights in pinned CPU memory during the training step; on each model parameter update, the EMA weights are transferred to the GPU, updated there, and then sent back to the CPU. Both transfers are non-blocking, so CUDA devices can overlap them with other computation. With sufficient bandwidth between host and device, and a sufficiently long gap between parameter updates, storing EMA weights in CPU RAM should add no performance overhead, as long as no other calls force synchronization.
+
 #### Training with DREAM
 
 We support training epsilon (noise) prediction models using the [DREAM (Diffusion Rectification and Estimation-Adaptive Models) strategy](https://arxiv.org/abs/2312.00210). DREAM claims to increase model fidelity for the performance cost of an extra grad-less unet `forward` step in the training loop. You can turn on DREAM training by using the `--dream_training` argument. The `--dream_detail_preservation` argument controls the detail preservation variable p and defaults to 1, as in the paper.
 
 
+
 ## Training with LoRA
 
 Low-Rank Adaptation of Large Language Models was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.
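
The EMA offload pattern described in the new README section can also be exercised outside of the training script. Below is a minimal, hypothetical sketch using `EMAModel` directly, with a toy `torch.nn.Linear` standing in for the UNet and a CUDA device assumed; the calls mirror what `--foreach_ema` and `--offload_ema` do in the example script.

```python
# Minimal sketch of the EMA offloading pattern described above.
# A toy Linear model stands in for the UNet; a CUDA device is assumed.
import torch

from diffusers.training_utils import EMAModel

model = torch.nn.Linear(8, 8)  # build on CPU so the EMA copies start on the host
ema = EMAModel(model.parameters(), decay=0.9999, foreach=True)
ema.pin_memory()  # keep the EMA copies in pinned CPU RAM for async transfers
model.cuda()

optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
for _ in range(10):
    loss = model(torch.randn(4, 8, device="cuda")).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    ema.to(device="cuda", non_blocking=True)  # async host-to-device copy
    ema.step(model.parameters())              # EMA update runs on the GPU
    ema.to(device="cpu", non_blocking=True)   # async device-to-host copy
```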

examples/text_to_image/train_text_to_image.py

Lines changed: 23 additions & 4 deletions
@@ -387,6 +387,8 @@ def parse_args():
         ),
     )
     parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
+    parser.add_argument("--offload_ema", action="store_true", help="Offload EMA model to CPU during training step.")
+    parser.add_argument("--foreach_ema", action="store_true", help="Use faster foreach implementation of EMAModel.")
     parser.add_argument(
         "--non_ema_revision",
         type=str,
@@ -624,7 +626,12 @@ def deepspeed_zero_init_disabled_context_manager():
         ema_unet = UNet2DConditionModel.from_pretrained(
             args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
         )
-        ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config)
+        ema_unet = EMAModel(
+            ema_unet.parameters(),
+            model_cls=UNet2DConditionModel,
+            model_config=ema_unet.config,
+            foreach=args.foreach_ema,
+        )
 
     if args.enable_xformers_memory_efficient_attention:
         if is_xformers_available():
@@ -655,9 +662,14 @@ def save_model_hook(models, weights, output_dir):
 
         def load_model_hook(models, input_dir):
            if args.use_ema:
-                load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel)
+                load_model = EMAModel.from_pretrained(
+                    os.path.join(input_dir, "unet_ema"), UNet2DConditionModel, foreach=args.foreach_ema
+                )
                 ema_unet.load_state_dict(load_model.state_dict())
-                ema_unet.to(accelerator.device)
+                if args.offload_ema:
+                    ema_unet.pin_memory()
+                else:
+                    ema_unet.to(accelerator.device)
                 del load_model
 
             for _ in range(len(models)):
@@ -833,7 +845,10 @@ def collate_fn(examples):
     )
 
     if args.use_ema:
-        ema_unet.to(accelerator.device)
+        if args.offload_ema:
+            ema_unet.pin_memory()
+        else:
+            ema_unet.to(accelerator.device)
 
     # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
     # as these weights are only used for inference, keeping weights in full precision is not required.
@@ -1011,7 +1026,11 @@ def unwrap_model(model):
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
                 if args.use_ema:
+                    if args.offload_ema:
+                        ema_unet.to(device="cuda", non_blocking=True)
                     ema_unet.step(unet.parameters())
+                    if args.offload_ema:
+                        ema_unet.to(device="cpu", non_blocking=True)
                 progress_bar.update(1)
                 global_step += 1
                 accelerator.log({"train_loss": train_loss}, step=global_step)

src/diffusers/training_utils.py

Lines changed: 60 additions & 14 deletions
@@ -274,6 +274,7 @@ def __init__(
         use_ema_warmup: bool = False,
         inv_gamma: Union[float, int] = 1.0,
         power: Union[float, int] = 2 / 3,
+        foreach: bool = False,
         model_cls: Optional[Any] = None,
         model_config: Dict[str, Any] = None,
         **kwargs,
@@ -288,6 +289,7 @@ def __init__(
             inv_gamma (float):
                 Inverse multiplicative factor of EMA warmup. Default: 1. Only used if `use_ema_warmup` is True.
             power (float): Exponential factor of EMA warmup. Default: 2/3. Only used if `use_ema_warmup` is True.
+            foreach (bool): Use torch._foreach functions for updating shadow parameters. Should be faster.
             device (Optional[Union[str, torch.device]]): The device to store the EMA weights on. If None, the EMA
                 weights will be stored on CPU.
 
@@ -342,16 +344,17 @@ def __init__(
         self.power = power
         self.optimization_step = 0
         self.cur_decay_value = None  # set in `step()`
+        self.foreach = foreach
 
         self.model_cls = model_cls
         self.model_config = model_config
 
     @classmethod
-    def from_pretrained(cls, path, model_cls) -> "EMAModel":
+    def from_pretrained(cls, path, model_cls, foreach=False) -> "EMAModel":
         _, ema_kwargs = model_cls.load_config(path, return_unused_kwargs=True)
         model = model_cls.from_pretrained(path)
 
-        ema_model = cls(model.parameters(), model_cls=model_cls, model_config=model.config)
+        ema_model = cls(model.parameters(), model_cls=model_cls, model_config=model.config, foreach=foreach)
 
         ema_model.load_state_dict(ema_kwargs)
         return ema_model
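
For reference, here is a small, hypothetical round trip exercising the new `foreach` argument of `from_pretrained`; the tiny checkpoint id is borrowed from the test file below, and the output directory name is arbitrary.

```python
# Hypothetical save/load round trip for an EMAModel with foreach enabled.
from diffusers import UNet2DConditionModel
from diffusers.training_utils import EMAModel

unet = UNet2DConditionModel.from_pretrained(
    "hf-internal-testing/tiny-stable-diffusion-pipe", subfolder="unet"
)
ema = EMAModel(unet.parameters(), model_cls=UNet2DConditionModel, model_config=unet.config, foreach=True)
ema.save_pretrained("unet_ema")  # writes the model config plus the EMA state
ema = EMAModel.from_pretrained("unet_ema", UNet2DConditionModel, foreach=True)
```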
@@ -418,15 +421,37 @@ def step(self, parameters: Iterable[torch.nn.Parameter]):
         if is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled():
             import deepspeed
 
-        for s_param, param in zip(self.shadow_params, parameters):
+        if self.foreach:
             if is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled():
-                context_manager = deepspeed.zero.GatheredParameters(param, modifier_rank=None)
+                context_manager = deepspeed.zero.GatheredParameters(parameters, modifier_rank=None)
 
             with context_manager():
-                if param.requires_grad:
-                    s_param.sub_(one_minus_decay * (s_param - param))
-                else:
-                    s_param.copy_(param)
+                params_grad = [param for param in parameters if param.requires_grad]
+                s_params_grad = [
+                    s_param for s_param, param in zip(self.shadow_params, parameters) if param.requires_grad
+                ]
+
+                if len(params_grad) < len(parameters):
+                    torch._foreach_copy_(
+                        [s_param for s_param, param in zip(self.shadow_params, parameters) if not param.requires_grad],
+                        [param for param in parameters if not param.requires_grad],
+                        non_blocking=True,
+                    )
+
+                torch._foreach_sub_(
+                    s_params_grad, torch._foreach_sub(s_params_grad, params_grad), alpha=one_minus_decay
+                )
+
+        else:
+            for s_param, param in zip(self.shadow_params, parameters):
+                if is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled():
+                    context_manager = deepspeed.zero.GatheredParameters(param, modifier_rank=None)
+
+                with context_manager():
+                    if param.requires_grad:
+                        s_param.sub_(one_minus_decay * (s_param - param))
+                    else:
+                        s_param.copy_(param)
 
     def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
         """
@@ -438,18 +463,34 @@ def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
                 `ExponentialMovingAverage` was initialized will be used.
         """
         parameters = list(parameters)
-        for s_param, param in zip(self.shadow_params, parameters):
-            param.data.copy_(s_param.to(param.device).data)
+        if self.foreach:
+            torch._foreach_copy_(
+                [param.data for param in parameters],
+                [s_param.to(param.device).data for s_param, param in zip(self.shadow_params, parameters)],
+            )
+        else:
+            for s_param, param in zip(self.shadow_params, parameters):
+                param.data.copy_(s_param.to(param.device).data)
 
-    def to(self, device=None, dtype=None) -> None:
+    def pin_memory(self) -> None:
+        r"""
+        Move internal buffers of the ExponentialMovingAverage to pinned memory. Useful for non-blocking transfers for
+        offloading EMA params to the host.
+        """
+
+        self.shadow_params = [p.pin_memory() for p in self.shadow_params]
+
+    def to(self, device=None, dtype=None, non_blocking=False) -> None:
         r"""Move internal buffers of the ExponentialMovingAverage to `device`.
 
         Args:
             device: like `device` argument to `torch.Tensor.to`
         """
         # .to() on the tensors handles None correctly
         self.shadow_params = [
-            p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
+            p.to(device=device, dtype=dtype, non_blocking=non_blocking)
+            if p.is_floating_point()
+            else p.to(device=device, non_blocking=non_blocking)
             for p in self.shadow_params
         ]
 
@@ -493,8 +534,13 @@ def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:
         """
         if self.temp_stored_params is None:
            raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights " "to `restore()`")
-        for c_param, param in zip(self.temp_stored_params, parameters):
-            param.data.copy_(c_param.data)
+        if self.foreach:
+            torch._foreach_copy_(
+                [param.data for param in parameters], [c_param.data for c_param in self.temp_stored_params]
+            )
+        else:
+            for c_param, param in zip(self.temp_stored_params, parameters):
+                param.data.copy_(c_param.data)
 
         # Better memory-wise.
         self.temp_stored_params = None
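
`copy_to` and `restore` are typically used together with `store()` around validation: stash the training weights, evaluate with the EMA weights, then swap the training weights back. A minimal, hypothetical sketch with a toy model:

```python
# Hypothetical validation-time usage of store/copy_to/restore with foreach enabled;
# a toy Linear model stands in for the real network.
import torch

from diffusers.training_utils import EMAModel

model = torch.nn.Linear(4, 4)
ema = EMAModel(model.parameters(), decay=0.999, foreach=True)

ema.store(model.parameters())    # stash the current training weights
ema.copy_to(model.parameters())  # evaluate with the EMA weights
# ... run validation or sample images here ...
ema.restore(model.parameters())  # put the training weights back
```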

tests/others/test_ema.py

Lines changed: 135 additions & 0 deletions
@@ -157,3 +157,138 @@ def test_serialization(self):
             output_loaded = loaded_unet(noisy_latents, timesteps, encoder_hidden_states).sample
 
             assert torch.allclose(output, output_loaded, atol=1e-4)
+
+
+class EMAModelTestsForeach(unittest.TestCase):
+    model_id = "hf-internal-testing/tiny-stable-diffusion-pipe"
+    batch_size = 1
+    prompt_length = 77
+    text_encoder_hidden_dim = 32
+    num_in_channels = 4
+    latent_height = latent_width = 64
+    generator = torch.manual_seed(0)
+
+    def get_models(self, decay=0.9999):
+        unet = UNet2DConditionModel.from_pretrained(self.model_id, subfolder="unet")
+        unet = unet.to(torch_device)
+        ema_unet = EMAModel(
+            unet.parameters(), decay=decay, model_cls=UNet2DConditionModel, model_config=unet.config, foreach=True
+        )
+        return unet, ema_unet
+
+    def get_dummy_inputs(self):
+        noisy_latents = torch.randn(
+            self.batch_size, self.num_in_channels, self.latent_height, self.latent_width, generator=self.generator
+        ).to(torch_device)
+        timesteps = torch.randint(0, 1000, size=(self.batch_size,), generator=self.generator).to(torch_device)
+        encoder_hidden_states = torch.randn(
+            self.batch_size, self.prompt_length, self.text_encoder_hidden_dim, generator=self.generator
+        ).to(torch_device)
+        return noisy_latents, timesteps, encoder_hidden_states
+
+    def simulate_backprop(self, unet):
+        updated_state_dict = {}
+        for k, param in unet.state_dict().items():
+            updated_param = torch.randn_like(param) + (param * torch.randn_like(param))
+            updated_state_dict.update({k: updated_param})
+        unet.load_state_dict(updated_state_dict)
+        return unet
+
+    def test_optimization_steps_updated(self):
+        unet, ema_unet = self.get_models()
+        # Take the first (hypothetical) EMA step.
+        ema_unet.step(unet.parameters())
+        assert ema_unet.optimization_step == 1
+
+        # Take two more.
+        for _ in range(2):
+            ema_unet.step(unet.parameters())
+        assert ema_unet.optimization_step == 3
+
+    def test_shadow_params_not_updated(self):
+        unet, ema_unet = self.get_models()
+        # Since the `unet` is not being updated (i.e., backprop'd)
+        # there won't be any difference between the `params` of `unet`
+        # and `ema_unet` even if we call `ema_unet.step(unet.parameters())`.
+        ema_unet.step(unet.parameters())
+        orig_params = list(unet.parameters())
+        for s_param, param in zip(ema_unet.shadow_params, orig_params):
+            assert torch.allclose(s_param, param)
+
+        # The above holds true even if we call `ema.step()` multiple times since
+        # `unet` params are still not being updated.
+        for _ in range(4):
+            ema_unet.step(unet.parameters())
+        for s_param, param in zip(ema_unet.shadow_params, orig_params):
+            assert torch.allclose(s_param, param)
+
+    def test_shadow_params_updated(self):
+        unet, ema_unet = self.get_models()
+        # Here we simulate the parameter updates for `unet`. Since there might
+        # be some parameters which are initialized to zero we take extra care to
+        # initialize their values to something non-zero before the multiplication.
+        unet_pseudo_updated_step_one = self.simulate_backprop(unet)
+
+        # Take the EMA step.
+        ema_unet.step(unet_pseudo_updated_step_one.parameters())
+
+        # Now the EMA'd parameters won't be equal to the original model parameters.
+        orig_params = list(unet_pseudo_updated_step_one.parameters())
+        for s_param, param in zip(ema_unet.shadow_params, orig_params):
+            assert ~torch.allclose(s_param, param)
+
+        # Ensure this is the case when we take multiple EMA steps.
+        for _ in range(4):
+            ema_unet.step(unet.parameters())
+        for s_param, param in zip(ema_unet.shadow_params, orig_params):
+            assert ~torch.allclose(s_param, param)
+
+    def test_consecutive_shadow_params_updated(self):
+        # If we call EMA step after a backpropagation consecutively for two times,
+        # the shadow params from those two steps should be different.
+        unet, ema_unet = self.get_models()
+
+        # First backprop + EMA
+        unet_step_one = self.simulate_backprop(unet)
+        ema_unet.step(unet_step_one.parameters())
+        step_one_shadow_params = ema_unet.shadow_params
+
+        # Second backprop + EMA
+        unet_step_two = self.simulate_backprop(unet_step_one)
+        ema_unet.step(unet_step_two.parameters())
+        step_two_shadow_params = ema_unet.shadow_params
+
+        for step_one, step_two in zip(step_one_shadow_params, step_two_shadow_params):
+            assert ~torch.allclose(step_one, step_two)
+
+    def test_zero_decay(self):
+        # If there's no decay even if there are backprops, EMA steps
+        # won't take any effect i.e., the shadow params would remain the
+        # same.
+        unet, ema_unet = self.get_models(decay=0.0)
+        unet_step_one = self.simulate_backprop(unet)
+        ema_unet.step(unet_step_one.parameters())
+        step_one_shadow_params = ema_unet.shadow_params
+
+        unet_step_two = self.simulate_backprop(unet_step_one)
+        ema_unet.step(unet_step_two.parameters())
+        step_two_shadow_params = ema_unet.shadow_params
+
+        for step_one, step_two in zip(step_one_shadow_params, step_two_shadow_params):
+            assert torch.allclose(step_one, step_two)
+
+    @skip_mps
+    def test_serialization(self):
+        unet, ema_unet = self.get_models()
+        noisy_latents, timesteps, encoder_hidden_states = self.get_dummy_inputs()
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            ema_unet.save_pretrained(tmpdir)
+            loaded_unet = UNet2DConditionModel.from_pretrained(tmpdir, model_cls=UNet2DConditionModel)
+            loaded_unet = loaded_unet.to(unet.device)
+
+            # Since no EMA step has been performed the outputs should match.
+            output = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+            output_loaded = loaded_unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+            assert torch.allclose(output, output_loaded, atol=1e-4)
