| 67 | 67 |     FlowMatchEulerDiscreteScheduler, | 
| 68 | 68 |     FluxPipeline, | 
| 69 | 69 |     FluxTransformer2DModel, | 
|  | 70 | +    ParallelConfig, | 
|  | 71 | +    enable_parallelism, | 
| 70 | 72 | ) | 
| 71 | 73 | from diffusers.optimization import get_scheduler | 
| 72 | 74 | from diffusers.training_utils import ( | 
| @@ -805,6 +807,8 @@ def parse_args(input_args=None): | 
| 805 | 807 |         ], | 
| 806 | 808 |         help="The image interpolation method to use for resizing images.", | 
| 807 | 809 |     ) | 
|  | 810 | +    parser.add_argument("--context_parallel_degree", type=int, default=1, help="The degree for context parallelism.") | 
|  | 811 | +    parser.add_argument("--context_parallel_type", type=str, default="ulysses", choices=["ulysses", "ring"], help="The type of context parallelism to use. Choose between 'ulysses' and 'ring'.") | 
| 808 | 812 | 
 | 
| 809 | 813 |     if input_args is not None: | 
| 810 | 814 |         args = parser.parse_args(input_args) | 
| @@ -1347,15 +1351,28 @@ def main(args): | 
| 1347 | 1351 | 
 | 
| 1348 | 1352 |     logging_dir = Path(args.output_dir, args.logging_dir) | 
| 1349 | 1353 | 
 | 
|  | 1354 | +    cp_degree = args.context_parallel_degree | 
| 1350 | 1355 |     accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) | 
| 1351 |  | -    kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) | 
|  | 1356 | +    if cp_degree > 1: | 
|  | 1357 | +        kwargs = [] | 
|  | 1358 | +    else: | 
|  | 1359 | +        kwargs = [DistributedDataParallelKwargs(find_unused_parameters=True)] | 
| 1352 | 1360 |     accelerator = Accelerator( | 
| 1353 | 1361 |         gradient_accumulation_steps=args.gradient_accumulation_steps, | 
| 1354 | 1362 |         mixed_precision=args.mixed_precision, | 
| 1355 | 1363 |         log_with=args.report_to, | 
| 1356 | 1364 |         project_config=accelerator_project_config, | 
| 1357 |  | -        kwargs_handlers=[kwargs], | 
| 1358 |  | -    ) | 
|  | 1365 | +        kwargs_handlers=kwargs, | 
|  | 1366 | +    ) | 
|  | 1367 | +    if cp_degree > 1 and not torch.distributed.is_initialized(): | 
|  | 1368 | +        if not torch.cuda.is_available(): | 
|  | 1369 | +            raise ValueError("Context parallelism is only tested on CUDA devices.") | 
|  | 1370 | +        if os.environ.get("WORLD_SIZE", None) is None: | 
|  | 1371 | +            raise ValueError("Try launching the program with `torchrun --nproc_per_node <NUM_GPUS>` instead of `accelerate launch <NUM_GPUS>`.") | 
|  | 1372 | +        torch.distributed.init_process_group("nccl") | 
|  | 1373 | +        # Use the Accelerate process index to pin each process to its GPU. | 
|  | 1374 | +        rank = accelerator.process_index | 
|  | 1375 | +        torch.cuda.set_device(torch.device("cuda", rank % torch.cuda.device_count())) | 
| 1359 | 1376 | 
 | 
| 1360 | 1377 |     # Disable AMP for MPS. | 
| 1361 | 1378 |     if torch.backends.mps.is_available(): | 
| @@ -1977,6 +1994,14 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): | 
| 1977 | 1994 |         power=args.lr_power, | 
| 1978 | 1995 |     ) | 
| 1979 | 1996 | 
 | 
|  | 1997 | +    # Enable context parallelism | 
|  | 1998 | +    if cp_degree > 1: | 
|  | 1999 | +        ring_degree = cp_degree if args.context_parallel_type == "ring" else None | 
|  | 2000 | +        ulysses_degree = cp_degree if args.context_parallel_type == "ulysses" else None | 
|  | 2001 | +        transformer.parallelize(config=ParallelConfig(ring_degree=ring_degree, ulysses_degree=ulysses_degree)) | 
|  | 2002 | +        transformer.set_attention_backend("_native_cudnn") | 
|  | 2003 | +    parallel_context = enable_parallelism(transformer) if cp_degree > 1 else nullcontext() | 
|  | 2004 | + | 
| 1980 | 2005 |     # Prepare everything with our `accelerator`. | 
| 1981 | 2006 |     if not freeze_text_encoder: | 
| 1982 | 2007 |         if args.enable_t5_ti: | 
| @@ -2131,7 +2156,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): | 
| 2131 | 2156 |                 logger.info(f"PIVOT TRANSFORMER {epoch}") | 
| 2132 | 2157 |                 optimizer.param_groups[0]["lr"] = 0.0 | 
| 2133 | 2158 | 
 | 
| 2134 |  | -            with accelerator.accumulate(models_to_accumulate): | 
|  | 2159 | +            with accelerator.accumulate(models_to_accumulate), parallel_context: | 
| 2135 | 2160 |                 prompts = batch["prompts"] | 
| 2136 | 2161 | 
 | 
| 2137 | 2162 |                 # encode batch prompts when custom prompts are provided for each image - | 
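
For reference, the snippet below collects the scattered hunks above into one standalone sketch of the context-parallel setup this PR wires in (process-group init, `ParallelConfig` + `parallelize`, `set_attention_backend`, and the `enable_parallelism` context). It is illustrative only: the FLUX checkpoint name, the bfloat16 dtype, and the single-model scope are assumptions, and the real script keeps these steps interleaved with the Accelerate setup as shown in the diff.

```python
# Minimal sketch (not part of the diff): the context-parallel wiring in one place.
# Checkpoint name and dtype are illustrative assumptions.
import os
from contextlib import nullcontext

import torch
from diffusers import FluxTransformer2DModel, ParallelConfig, enable_parallelism

# Launch with: torchrun --nproc_per_node <NUM_GPUS> <this_script.py>
torch.distributed.init_process_group("nccl")
rank = torch.distributed.get_rank()
torch.cuda.set_device(torch.device("cuda", rank % torch.cuda.device_count()))

cp_degree = int(os.environ["WORLD_SIZE"])  # here the CP degree spans all ranks
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
).cuda()

if cp_degree > 1:
    # "ulysses" shown here; pass ring_degree=cp_degree instead for ring attention.
    transformer.parallelize(config=ParallelConfig(ulysses_degree=cp_degree))
    transformer.set_attention_backend("_native_cudnn")

parallel_context = enable_parallelism(transformer) if cp_degree > 1 else nullcontext()
with parallel_context:
    # forward/backward passes on the sequence-sharded model go inside this context
    ...
```

With these changes the training script would be started with something like `torchrun --nproc_per_node 2 <train_script.py> --context_parallel_degree 2 --context_parallel_type ulysses ...` rather than `accelerate launch`, matching the check added around `torch.distributed.init_process_group`.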