
Commit 0018b62

Revert "try to make dreambooth script work; accelerator backward not playing well"
This reverts commit 768d0ea.
1 parent 768d0ea commit 0018b62

File tree

1 file changed: +4 -29 lines changed

examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py

Lines changed: 4 additions & 29 deletions
@@ -67,8 +67,6 @@
     FlowMatchEulerDiscreteScheduler,
     FluxPipeline,
     FluxTransformer2DModel,
-    ParallelConfig,
-    enable_parallelism,
 )
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import (
@@ -807,8 +805,6 @@ def parse_args(input_args=None):
         ],
         help="The image interpolation method to use for resizing images.",
     )
-    parser.add_argument("--context_parallel_degree", type=int, default=1, help="The degree for context parallelism.")
-    parser.add_argument("--context_parallel_type", type=str, default="ulysses", help="The type of context parallelism to use. Choose between 'ulysses' and 'ring'.")
 
     if input_args is not None:
         args = parser.parse_args(input_args)
@@ -1351,28 +1347,15 @@ def main(args):
 
     logging_dir = Path(args.output_dir, args.logging_dir)
 
-    cp_degree = args.context_parallel_degree
     accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
-    if cp_degree > 1:
-        kwargs = []
-    else:
-        kwargs = [DistributedDataParallelKwargs(find_unused_parameters=True)]
+    kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
     accelerator = Accelerator(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision,
         log_with=args.report_to,
         project_config=accelerator_project_config,
-        kwargs_handlers=kwargs,
-    )
-    if cp_degree > 1 and not torch.distributed.is_initialized():
-        if not torch.cuda.is_available():
-            raise ValueError("Context parallelism is only tested on CUDA devices.")
-        if os.environ.get("WORLD_SIZE", None) is None:
-            raise ValueError("Try launching the program with `torchrun --nproc_per_node <NUM_GPUS>` instead of `accelerate launch <NUM_GPUS>`.")
-        torch.distributed.init_process_group("nccl")
-        rank = torch.distributed.get_rank()
-        rank = accelerator.process_index
-        torch.cuda.set_device(torch.device("cuda", rank % torch.cuda.device_count()))
+        kwargs_handlers=[kwargs],
+    )
 
     # Disable AMP for MPS.
     if torch.backends.mps.is_available():
@@ -1994,14 +1977,6 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
         power=args.lr_power,
     )
 
-    # Enable context parallelism
-    if cp_degree > 1:
-        ring_degree = cp_degree if args.context_parallel_type == "ring" else None
-        ulysses_degree = cp_degree if args.context_parallel_type == "ulysses" else None
-        transformer.parallelize(config=ParallelConfig(ring_degree=ring_degree, ulysses_degree=ulysses_degree))
-        transformer.set_attention_backend("_native_cudnn")
-    parallel_context = enable_parallelism(transformer) if cp_degree > 1 else nullcontext()
-
     # Prepare everything with our `accelerator`.
     if not freeze_text_encoder:
         if args.enable_t5_ti:
@@ -2156,7 +2131,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 logger.info(f"PIVOT TRANSFORMER {epoch}")
                 optimizer.param_groups[0]["lr"] = 0.0
 
-            with accelerator.accumulate(models_to_accumulate), parallel_context:
+            with accelerator.accumulate(models_to_accumulate):
                 prompts = batch["prompts"]
 
                 # encode batch prompts when custom prompts are provided for each image -
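
Note on the restored setup: after this revert, the script goes back to passing a single DistributedDataParallelKwargs handler to Accelerator and drops the context-parallel branch entirely, so the training loop's `with accelerator.accumulate(...)` block no longer needs a separate parallel context manager. Below is a minimal sketch of that restored pattern; the concrete values stand in for the script's `args.*` fields and are illustrative, not taken from the commit.

# Minimal sketch of the setup this revert restores: one DistributedDataParallelKwargs
# handler passed to Accelerator via kwargs_handlers. Placeholder values stand in for
# the script's args.* fields (output dir, mixed precision, tracker).
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration

project_config = ProjectConfiguration(project_dir="output", logging_dir="output/logs")
kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
accelerator = Accelerator(
    gradient_accumulation_steps=1,   # args.gradient_accumulation_steps
    mixed_precision="bf16",          # args.mixed_precision
    log_with="tensorboard",          # args.report_to
    project_config=project_config,
    kwargs_handlers=[kwargs],        # a list of KwargsHandler instances
)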
