Commit 215104f
1 parent: fa5d017

Commit message: update

2 files changed: 18 additions, 2 deletions

src/diffusers/hooks/context_parallel.py

Lines changed: 16 additions & 1 deletion
@@ -233,6 +233,21 @@ def post_forward(self, module, output):
         return output[0] if is_tensor else tuple(output)
 
 
+class AllGatherFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, tensor, dim, group):
+        ctx.dim = dim
+        ctx.group = group
+        ctx.world_size = torch.distributed.get_world_size(group)
+        ctx.rank = torch.distributed.get_rank(group)
+        return funcol.all_gather_tensor(tensor, dim, group=group)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        grad_chunks = torch.chunk(grad_output, ctx.world_size, dim=ctx.dim)
+        return grad_chunks[ctx.rank], None, None
+
+
 class EquipartitionSharder:
     @classmethod
     def shard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
@@ -246,7 +261,7 @@ def shard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_me
     @classmethod
     def unshard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
         tensor = tensor.contiguous()
-        tensor = funcol.all_gather_tensor(tensor, dim, group=mesh.get_group())
+        tensor = AllGatherFunction.apply(tensor, dim, mesh.get_group())
        return tensor
 
 
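With this change, EquipartitionSharder.unshard becomes differentiable: the forward pass still all-gathers the shards along dim, while the new backward pass hands each rank only the gradient slice that corresponds to its own shard. The snippet below is a minimal, self-contained sketch of that gradient behavior, not the diffusers code itself: it mirrors the autograd.Function pattern above but uses torch.distributed.all_gather on the gloo backend instead of funcol, and the DifferentiableAllGather and _worker names are illustrative assumptions.

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


class DifferentiableAllGather(torch.autograd.Function):
    """Sketch of an all-gather whose backward routes gradients back to the local shard."""

    @staticmethod
    def forward(ctx, tensor, dim, group):
        ctx.dim = dim
        ctx.world_size = dist.get_world_size(group)
        ctx.rank = dist.get_rank(group)
        chunks = [torch.empty_like(tensor) for _ in range(ctx.world_size)]
        dist.all_gather(chunks, tensor.contiguous(), group=group)
        return torch.cat(chunks, dim=dim)

    @staticmethod
    def backward(ctx, grad_output):
        # Keep only the gradient slice that corresponds to this rank's shard.
        grad_chunks = torch.chunk(grad_output, ctx.world_size, dim=ctx.dim)
        return grad_chunks[ctx.rank], None, None


def _worker(rank, world_size):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Each rank holds one shard of the gathered dimension.
    shard = torch.full((2, 4), float(rank), requires_grad=True)
    gathered = DifferentiableAllGather.apply(shard, 0, dist.group.WORLD)

    gathered.sum().backward()
    # Each rank's shard receives only the gradient of its own slice (all ones here).
    print(f"rank {rank}: grad = {shard.grad}")
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(_worker, args=(2,), nprocs=2)

Running this on a single machine with two processes should print a gradient of all ones with shape (2, 4) on each rank, confirming that the gradient of the gathered tensor flows back to the local shard only.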

src/diffusers/models/attention_dispatch.py

Lines changed: 2 additions & 1 deletion
@@ -281,7 +281,8 @@ def dispatch_attention_fn(
         and not _AttentionBackendRegistry._is_context_parallel_enabled(backend_name)
     ):
         raise ValueError(
-            f"Backend {backend_name} does not support context parallelism, but a parallel configuration is provided."
+            f"Backend {backend_name} either does not support context parallelism or context parallelism "
+            f"was enabled with a world size of 1."
         )
 
     kwargs = {
