
Commit 3e0cd93

Add gradient checkpointing

1 parent 2c0b2a7

File tree

3 files changed: +55 -19 lines changed


examples/research_projects/pytorch_xla/training/text_to_image/train_text_to_image_sdxl.py

Lines changed: 46 additions & 9 deletions
@@ -21,8 +21,13 @@
 from torchvision import transforms
 from torchvision.transforms.functional import crop
 from transformers import CLIPTextModel, CLIPTextModelWithProjection, AutoTokenizer
+from transformers.trainer_pt_utils import get_module_class_from_name
 from viztracer import VizTracer
 
+from torch._dispatch.python import suspend_functionalization
+from torch._subclasses.functional_tensor import disable_functional_mode
+
+from torch_xla.distributed.fsdp import checkpoint_module
 from diffusers import (
     AutoencoderKL,
     DDPMScheduler,
@@ -118,6 +123,35 @@ def main(args):
     model_card.save(os.path.join(repo_folder, "README.md"))
 
 
+def wrap_module(
+    mod: torch.nn.Module, transform, prefix: tuple[str, ...] = tuple()
+) -> torch.nn.Module:
+    """
+    Recursively transforms the modules by calling `transform` on them.
+
+    You may use this to apply sharding, checkpointing, optimization barriers, etc.
+
+    Start from the leaf modules and work our way up, to handle cases where one
+    module is the child of another. The child modules will be transformed first,
+    and then the parent module will be transformed, possibly with transformed
+    children.
+    """
+    new_children = {}
+    for name, child in mod.named_children():
+        new_children[name] = wrap_module(child, transform, prefix + (name,))
+    for name, new_child in new_children.items():
+        mod.set_submodule(name, new_child)
+    return transform(mod)
+
+def add_checkpoints(model):
+    remat_classes = [get_module_class_from_name(model, "BasicTransformerBlock")]
+    import pdb; pdb.set_trace()
+    def maybe_checkpoint(mod):
+        if isinstance(mod, tuple(remat_classes)):
+            return checkpoint_module(mod)
+        return mod
+    return wrap_module(model, maybe_checkpoint)
+
 class TrainSD:
     def __init__(
         self,
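To see the leaf-first recursion described in the docstring above in isolation, here is a minimal sketch that runs without torch_xla. The Block class and the tag_blocks transform are hypothetical stand-ins for the commit's BasicTransformerBlock and maybe_checkpoint; the real transform would return checkpoint_module(mod) on an XLA device.

import torch
import torch.nn as nn


def wrap_module(mod, transform, prefix=()):
    # Children are rewritten first, then the (possibly rewritten) parent.
    new_children = {}
    for name, child in mod.named_children():
        new_children[name] = wrap_module(child, transform, prefix + (name,))
    for name, new_child in new_children.items():
        setattr(mod, name, new_child)  # the commit uses mod.set_submodule(name, new_child)
    return transform(mod)


class Block(nn.Module):  # hypothetical stand-in for BasicTransformerBlock
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

    def forward(self, x):
        return torch.relu(self.proj(x))


model = nn.Sequential(Block(), nn.Sequential(Block(), Block()))
seen = []


def tag_blocks(mod):
    # Stand-in for maybe_checkpoint: record which modules would be wrapped.
    if isinstance(mod, Block):
        seen.append(mod)
    return mod


wrap_module(model, tag_blocks)
print(len(seen))  # 3: every Block is visited before its parent containers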
@@ -163,13 +197,14 @@ def start_training(self):
                 tracer = VizTracer()
             else:
                 tracer = None
-            loss = self.step_fn(
-                tracer,
-                batch["model_input"],
-                batch["prompt_embeds"],
-                batch["pooled_prompt_embeds"],
-                batch["original_sizes"],
-                batch["crop_top_lefts"])
+            with suspend_functionalization(), disable_functional_mode():
+                loss = self.step_fn(
+                    tracer,
+                    batch["model_input"],
+                    batch["prompt_embeds"],
+                    batch["pooled_prompt_embeds"],
+                    batch["original_sizes"],
+                    batch["crop_top_lefts"])
             self.global_step += 1
 
             def print_loss_closure(step, loss):
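The with-block introduced above relies on private torch internals (torch._dispatch.python and torch._subclasses.functional_tensor), so availability depends on the installed torch version. A stripped-down sketch of the same pattern, with a placeholder step_fn standing in for TrainSD.step_fn:

import torch
from torch._dispatch.python import suspend_functionalization
from torch._subclasses.functional_tensor import disable_functional_mode


def step_fn(batch):
    # Placeholder for TrainSD.step_fn: any forward/backward work goes here.
    return (batch * 2.0).mean()


batch = torch.ones(4)
# Both context managers are torch-internal and suspend functionalization
# for the ops executed inside the block.
with suspend_functionalization(), disable_functional_mode():
    loss = step_fn(batch)
print(loss.item())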
@@ -647,9 +682,9 @@ def main(args):
         use_fast=False
     )
 
-    from torch_xla.distributed.fsdp.utils import apply_xla_patch_to_nn_linear
+    # from torch_xla.distributed.fsdp.utils import apply_xla_patch_to_nn_linear
 
-    unet = apply_xla_patch_to_nn_linear(unet, xs.xla_patched_nn_linear_forward)
+    # unet = apply_xla_patch_to_nn_linear(unet, xs.xla_patched_nn_linear_forward)
     unet.enable_xla_flash_attention(partition_spec=("data", None, None, None))
 
     vae.requires_grad_(False)
@@ -810,6 +845,8 @@ def collate_fn(examples):
         f"Total train batch size (w. parallel, distributed & accumulation) = {args.train_batch_size * num_hosts}"
     )
     print(f" Total optimization steps = {args.max_train_steps}")
+
+    unet = add_checkpoints(unet)
 
     trainer = TrainSD(
         weight_dtype=weight_dtype,
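As background on what the wrapped blocks do once add_checkpoints(unet) runs: gradient checkpointing drops intermediate activations in the forward pass and recomputes them during backward, trading extra compute for lower peak memory. A generic, non-XLA illustration using torch.utils.checkpoint (not the checkpoint_module path this commit uses, just the same idea in upstream PyTorch):

import torch
from torch.utils.checkpoint import checkpoint

layer = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU())
x = torch.randn(2, 16, requires_grad=True)

# Activations inside `layer` are not kept; they are recomputed during backward.
y = checkpoint(layer, x, use_reentrant=False)
y.sum().backward()
print(x.grad.shape)  # torch.Size([2, 16])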

src/diffusers/models/attention_processor.py

Lines changed: 7 additions & 8 deletions
@@ -3390,12 +3390,12 @@ def scaled_dot_product_attention_jax(query, key, value):
     # x = wrapped_attention(query, key, value)
     # return x
 
-@functools.lru_cache(maxsize=16)
+@functools.lru_cache(maxsize=256)
 def _get_jax_forward_function():
     """Cached factory function to create JAX forward functions"""
     return scaled_dot_product_attention_jax
 
-@functools.lru_cache(maxsize=16)
+@functools.lru_cache(maxsize=256)
 def _get_jax_backward_function():
     """Cached factory function to create JAX backward functions"""
     jax_f = _get_jax_forward_function()
@@ -3419,14 +3419,12 @@ def scaled_dot_product_attention_jax_wrapper(query, key, value, grad_output=None
 class JaxFun(torch.autograd.Function):
     @staticmethod
     def forward(ctx, query, key, value):
-        # sample_inputs = [abstractify(query), abstractify(key), abstractify(value)]
         ctx.save_for_backward(query, key, value)
         out = scaled_dot_product_attention_jax_wrapper(query, key, value)
         return out
 
     @staticmethod
     def backward(ctx, grad_out):
-        # import pdb; pdb.set_trace()
         query, key, value = ctx.saved_tensors
         q_grad, k_grad, v_grad = scaled_dot_product_attention_jax_wrapper(query, key, value, grad_output=grad_out, is_forward=False)
         return q_grad, k_grad, v_grad
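For readers unfamiliar with the pattern, JaxFun follows the standard torch.autograd.Function protocol: stash inputs with save_for_backward in forward, read them back in backward, and return one gradient per input. A minimal self-contained example of that protocol (squaring instead of attention, so it runs without the JAX bridge):

import torch


class Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_out  # d(x^2)/dx = 2x, chained with grad_out


x = torch.randn(3, requires_grad=True)
Square.apply(x).sum().backward()
print(torch.allclose(x.grad, 2 * x))  # True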
@@ -3453,6 +3451,7 @@ def flash_attention(self, query, key, value):
         p = self.partition_spec if is_spmd() else None
         return flash_attention(query, key, value, causal=False, partition_spec=p)
 
+    @xp.trace_me("scaled_dot_product_attention")
     def scaled_dot_product_attention(self, query, key, value) -> torch.Tensor:
         scale_factor = 1 / math.sqrt(query.size(-1))
         attn_weight = query @ key.transpose(-2, -1) * scale_factor
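The hunk only shows the first two lines of the manual attention path; assuming the method continues with the usual softmax(QK^T * scale) @ V formulation, the quick check below (independent of the commit) confirms that this math matches torch's built-in SDPA kernel:

import math

import torch
import torch.nn.functional as F

q, k, v = (torch.randn(1, 2, 5, 8) for _ in range(3))

scale_factor = 1 / math.sqrt(q.size(-1))
attn_weight = torch.softmax(q @ k.transpose(-2, -1) * scale_factor, dim=-1)
manual = attn_weight @ v

print(torch.allclose(manual, F.scaled_dot_product_attention(q, k, v), atol=1e-5))  # True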
@@ -3537,10 +3536,10 @@ def __call__(
             # logger.warning(
             #     "Unable to use the flash attention pallas kernel API call due to QKV sequence length < 4096."
             # )
-            # hidden_states = self.scaled_dot_product_attention(
-            #     query, key, value
-            # )
-            hidden_states = JaxFun.apply(query, key, value)
+            hidden_states = self.scaled_dot_product_attention(
+                query, key, value
+            )
+            # hidden_states = JaxFun.apply(query, key, value)
 
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)

train.sh

Lines changed: 2 additions & 2 deletions
@@ -1,9 +1,9 @@
 export XLA_DISABLE_FUNCTIONALIZATION=1
-export TORCH_DISABLE_FUNCTIONALIZATION_META_REFERENCE=1
+# export TORCH_DISABLE_FUNCTIONALIZATION_META_REFERENCE=1
 export PROFILE_DIR=/mnt/bbahl/xla_profile/
 export CACHE_DIR=/mnt/bbahl/xla_cache/
 export DATASET_NAME=lambdalabs/naruto-blip-captions
-export PER_HOST_BATCH_SIZE=32 # This is known to work on TPU v4. Can set this to 64 for TPU v5p
+export PER_HOST_BATCH_SIZE=40 # This is known to work on TPU v4. Can set this to 64 for TPU v5p
 export TRAIN_STEPS=50
 export PROFILE_START_STEP=10
 export OUTPUT_DIR=/tmp/trained-model/
