 from torchvision import transforms
 from torchvision.transforms.functional import crop
 from transformers import CLIPTextModel, CLIPTextModelWithProjection, AutoTokenizer
+from transformers.trainer_pt_utils import get_module_class_from_name
 from viztracer import VizTracer

+from torch._dispatch.python import suspend_functionalization
+from torch._subclasses.functional_tensor import disable_functional_mode
+
+from torch_xla.distributed.fsdp import checkpoint_module
 from diffusers import (
     AutoencoderKL,
     DDPMScheduler,
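get_module_class_from_name comes from transformers' FSDP utilities and is used by add_checkpoints below to look a layer class up by its name. Roughly, it behaves like the following sketch; resolve_class_by_name is a hypothetical stand-in written from how the helper is used here, not the library implementation:

    import torch

    def resolve_class_by_name(model: torch.nn.Module, name: str):
        # Walk the module tree and return the class whose __name__ matches `name`,
        # e.g. resolve_class_by_name(unet, "BasicTransformerBlock").
        for module in model.modules():
            if module.__class__.__name__ == name:
                return module.__class__
        return None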
@@ -118,6 +123,35 @@ def main(args):
     model_card.save(os.path.join(repo_folder, "README.md"))


+def wrap_module(
+    mod: torch.nn.Module, transform, prefix: tuple[str, ...] = tuple()
+) -> torch.nn.Module:
+    """
+    Recursively transform `mod` and its submodules by calling `transform` on each.
+
+    Use this to apply sharding, checkpointing, optimization barriers, etc.
+
+    The traversal starts from the leaf modules and works its way up, so that
+    when one module is the child of another, the children are transformed
+    first and the parent is transformed afterwards, with its already-transformed
+    children in place.
+    """
+    new_children = {}
+    for name, child in mod.named_children():
+        new_children[name] = wrap_module(child, transform, prefix + (name,))
+    for name, new_child in new_children.items():
+        mod.set_submodule(name, new_child)
+    return transform(mod)
+
+def add_checkpoints(model):
+    remat_classes = [get_module_class_from_name(model, "BasicTransformerBlock")]
+    # Rematerialize (gradient-checkpoint) every matching transformer block.
+    def maybe_checkpoint(mod):
+        if isinstance(mod, tuple(remat_classes)):
+            return checkpoint_module(mod)
+        return mod
+    return wrap_module(model, maybe_checkpoint)
+
 class TrainSD:
     def __init__(
         self,
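As a usage note, not part of the patch: the same pair of helpers generalizes to rematerializing several layer classes at once. A minimal sketch, using the hypothetical name add_checkpoints_by_name and assuming the class names passed in actually occur in the model:

    def add_checkpoints_by_name(model, class_names=("BasicTransformerBlock",)):
        # Resolve each class name against the model, skipping names that don't occur.
        remat_classes = tuple(
            cls
            for cls in (get_module_class_from_name(model, name) for name in class_names)
            if cls is not None
        )

        def maybe_checkpoint(mod):
            # checkpoint_module wraps the module so its activations are recomputed
            # during the backward pass instead of being kept in memory.
            return checkpoint_module(mod) if isinstance(mod, remat_classes) else mod

        return wrap_module(model, maybe_checkpoint)

Calling add_checkpoints_by_name(unet, ("BasicTransformerBlock",)) would then reproduce what add_checkpoints does.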
@@ -163,13 +197,14 @@ def start_training(self):
                 tracer = VizTracer()
             else:
                 tracer = None
-            loss = self.step_fn(
-                tracer,
-                batch["model_input"],
-                batch["prompt_embeds"],
-                batch["pooled_prompt_embeds"],
-                batch["original_sizes"],
-                batch["crop_top_lefts"])
+            with suspend_functionalization(), disable_functional_mode():
+                loss = self.step_fn(
+                    tracer,
+                    batch["model_input"],
+                    batch["prompt_embeds"],
+                    batch["pooled_prompt_embeds"],
+                    batch["original_sizes"],
+                    batch["crop_top_lefts"])
             self.global_step += 1

             def print_loss_closure(step, loss):
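suspend_functionalization and disable_functional_mode live in private torch modules (note the underscore packages), so their location can change between releases. A hedged hardening sketch, not part of the patch, that falls back to no-op context managers when they are unavailable:

    import contextlib

    try:
        from torch._dispatch.python import suspend_functionalization
        from torch._subclasses.functional_tensor import disable_functional_mode
    except ImportError:
        # On torch builds where these private helpers have moved or been removed,
        # run the step with functionalization simply left enabled.
        suspend_functionalization = contextlib.nullcontext
        disable_functional_mode = contextlib.nullcontext

The with-block in the training step above would then run unchanged on either kind of build.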
@@ -647,9 +682,9 @@ def main(args):
         use_fast=False
     )

-    from torch_xla.distributed.fsdp.utils import apply_xla_patch_to_nn_linear
+    # from torch_xla.distributed.fsdp.utils import apply_xla_patch_to_nn_linear

-    unet = apply_xla_patch_to_nn_linear(unet, xs.xla_patched_nn_linear_forward)
+    # unet = apply_xla_patch_to_nn_linear(unet, xs.xla_patched_nn_linear_forward)
     unet.enable_xla_flash_attention(partition_spec=("data", None, None, None))

     vae.requires_grad_(False)
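The partition_spec=("data", None, None, None) refers to a mesh axis named "data", so the script is expected to have registered an SPMD mesh with such an axis elsewhere. A minimal sketch of that kind of setup, assuming a simple one-axis data-parallel mesh (the real script may shape and name its mesh differently):

    import numpy as np
    import torch_xla.runtime as xr
    import torch_xla.distributed.spmd as xs

    xr.use_spmd()  # enable SPMD execution mode
    num_devices = xr.global_runtime_device_count()
    # One-dimensional mesh whose single axis is named "data"; the flash-attention
    # partition_spec above refers to this axis name.
    mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ("data",))
    xs.set_global_mesh(mesh)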
@@ -810,6 +845,8 @@ def collate_fn(examples):
810 | 845 | f"Total train batch size (w. parallel, distributed & accumulation) = {args.train_batch_size * num_hosts}" |
811 | 846 | ) |
812 | 847 | print(f" Total optimization steps = {args.max_train_steps}") |
| 848 | + |
| 849 | + unet = add_checkpoints(unet) |
813 | 850 |
|
814 | 851 | trainer = TrainSD( |
815 | 852 | weight_dtype=weight_dtype, |