Commit 62f164d: "make torch compile compatible"
Parent: 171152f

File tree: 2 files changed (+53, -19 lines)

src/diffusers/hooks/context_parallel.py (37 additions, 13 deletions)
@@ -105,19 +105,41 @@ def apply_context_parallel(
         registry = HookRegistry.check_if_exists_or_initialize(m)
         registry.register_hook(hook, hook_name)
 
-    registry = HookRegistry.check_if_exists_or_initialize(module)
-    hook = ContextParallelModelHook(parallel_config)
-    registry.register_hook(hook, _CONTEXT_PARALLEL_MODEL_HOOK)
+    # HACK: we cannot use context managers or setattr or similar solutions in an overwritten forward
+    # diffusers hook method because Dynamo fails to trace it. Instead, we make use of module hooks
+    # available in pytorch to set the parallel context before/after the forward/backward pass.
+    # It is dirty, but fullgraph=True tracing works because of this and I haven't found a better solution yet.
+    # The previous/older implementation simply did this:
+    #     def new_forward(self, ...):
+    #         with _parallel_context(parallel_config):
+    #             return self.fn_ref.original_forward(*args, **kwargs)
+    # TODO: ask help from Pytorch team on how to improve this
+    @torch.compiler.disable
+    def forward_pre_hook(module, args):
+        module._diffusers_parallel_config_setter_context = _parallel_context(parallel_config)
+        module._diffusers_parallel_config_setter_context.__enter__()
 
+    @torch.compiler.disable
+    def forward_hook(module, args, output):
+        if module._diffusers_parallel_config_setter_context is not None:
+            module._diffusers_parallel_config_setter_context.__exit__(None, None, None)
+        module._diffusers_parallel_config_setter_context = None
 
-class ContextParallelModelHook(ModelHook):
-    def __init__(self, parallel_config: ParallelConfig) -> None:
-        super().__init__()
-        self.parallel_config = parallel_config
+    @torch.compiler.disable
+    def backward_pre_hook(module, grad_output):
+        module._diffusers_parallel_config_setter_context = _parallel_context(parallel_config)
+        module._diffusers_parallel_config_setter_context.__enter__()
 
-    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
-        with _parallel_context(self.parallel_config):
-            return self.fn_ref.original_forward(*args, **kwargs)
+    @torch.compiler.disable
+    def backward_hook(module, grad_output, grad_input):
+        if module._diffusers_parallel_config_setter_context is not None:
+            module._diffusers_parallel_config_setter_context.__exit__(None, None, None)
+        module._diffusers_parallel_config_setter_context = None
+
+    module.register_forward_pre_hook(forward_pre_hook)
+    module.register_forward_hook(forward_hook)
+    module.register_full_backward_pre_hook(backward_pre_hook)
+    module.register_full_backward_hook(backward_hook)
 
 
 class ContextParallelSplitHook(ModelHook):
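
Aside: a minimal, self-contained sketch of the hook-based pattern introduced in the hunk above. The names here (_example_context, attach_context_hooks, the Linear module) are illustrative stand-ins, not the diffusers code; the point is only that the context manager is entered/exited from PyTorch module hooks wrapped in torch.compiler.disable, instead of from a with-block inside an overridden forward that Dynamo would have to trace.

import contextlib

import torch


@contextlib.contextmanager
def _example_context():
    # Stand-in for _parallel_context(parallel_config)
    print("entering parallel context")
    try:
        yield
    finally:
        print("exiting parallel context")


def attach_context_hooks(module: torch.nn.Module) -> None:
    # Enter the context just before forward runs...
    @torch.compiler.disable
    def forward_pre_hook(module, args):
        module._ctx = _example_context()
        module._ctx.__enter__()

    # ...and exit it right after forward returns.
    @torch.compiler.disable
    def forward_hook(module, args, output):
        if module._ctx is not None:
            module._ctx.__exit__(None, None, None)
        module._ctx = None

    module.register_forward_pre_hook(forward_pre_hook)
    module.register_forward_hook(forward_hook)


model = torch.nn.Linear(4, 4)
attach_context_hooks(model)
model(torch.randn(2, 4))  # prints the enter/exit messages around the forward pass
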
@@ -234,13 +256,15 @@ def post_forward(self, module, output):
 
 class EquipartitionSharder:
     @classmethod
-    @torch.compiler.disable
     def shard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
         assert tensor.size()[dim] % mesh.size() == 0
-        return tensor.chunk(mesh.size(), dim=dim)[mesh.get_rank()]
+
+        # The following is not fullgraph compatible with Dynamo (fails in DeviceMesh.get_rank)
+        # return tensor.chunk(mesh.size(), dim=dim)[mesh.get_rank()]
+
+        return tensor.chunk(mesh.size(), dim=dim)[torch.distributed.get_rank(mesh.get_group())]
 
     @classmethod
-    @torch.compiler.disable
     def unshard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
         tensor = tensor.contiguous()
         tensor = funcol.all_gather_tensor(tensor, dim, group=mesh.get_group())
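
Aside: the shard() change above replaces mesh.get_rank() with torch.distributed.get_rank(mesh.get_group()), which returns the same rank for this mesh's process group but does not break Dynamo fullgraph tracing. A rough standalone sketch of that lookup, assuming a recent PyTorch and a single-process gloo group purely so the example runs anywhere:

import os

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh


def shard(tensor: torch.Tensor, dim: int, mesh) -> torch.Tensor:
    # Same idea as EquipartitionSharder.shard in the diff: index the chunk by the
    # rank looked up from the mesh's process group rather than via mesh.get_rank().
    assert tensor.size()[dim] % mesh.size() == 0
    rank = torch.distributed.get_rank(mesh.get_group())
    return tensor.chunk(mesh.size(), dim=dim)[rank]


if __name__ == "__main__":
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)
    mesh = init_device_mesh("cpu", (1,))
    x = torch.arange(8).reshape(1, 8)
    print(shard(x, dim=1, mesh=mesh))  # with world_size=1 this is the full tensor
    dist.destroy_process_group()
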

src/diffusers/models/attention_dispatch.py (16 additions, 6 deletions)
@@ -21,6 +21,7 @@
 
 import torch
 import torch.distributed._functional_collectives as funcol
+import torch.distributed.tensor
 
 from ..utils import (
     get_logger,
@@ -245,9 +246,6 @@ def attention_backend(backend: AttentionBackendName = AttentionBackendName.NATIV
 
 @contextlib.contextmanager
 def _parallel_context(parallel_config: "ParallelConfig"):
-    """
-    Context manager to set the parallel configuration for attention backends that support it.
-    """
     old_parallel_config = _AttentionBackendRegistry._parallel_config
     _AttentionBackendRegistry._parallel_config = parallel_config
 
@@ -789,6 +787,16 @@ def backward(
 # ===== Context parallel =====
 
 
+# Reference:
+# - https://github.com/pytorch/pytorch/blob/f58a680d09e13658a52c6ba05c63c15759846bcc/torch/distributed/_functional_collectives.py#L827
+# - https://github.com/pytorch/pytorch/blob/f58a680d09e13658a52c6ba05c63c15759846bcc/torch/distributed/_functional_collectives.py#L246
+# For fullgraph=True tracing compatibility (since FakeTensor does not have a `wait` method):
+def _wait_tensor(tensor):
+    if isinstance(tensor, funcol.AsyncCollectiveTensor):
+        tensor = tensor.wait()
+    return tensor
+
+
 class TemplatedRingAttention(torch.autograd.Function):
     @staticmethod
     def forward(
@@ -875,20 +883,22 @@ def forward(
             x.reshape(B, S_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
             for x in (query, key, value)
         )
-        query, key, value = (funcol.all_to_all_single(x, None, None, group=group).wait() for x in (query, key, value))
+        query, key, value = (
+            _wait_tensor(funcol.all_to_all_single(x, None, None, group=group)) for x in (query, key, value)
+        )
         query, key, value = (x.flatten(0, 1).permute(1, 0, 2, 3).contiguous() for x in (query, key, value))
 
         out = op.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, return_lse)
         if return_lse:
             out, lse, *_ = out
 
         out = out.reshape(B, world_size, S_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous()
-        out = funcol.all_to_all_single(out, None, None, group=group).wait()
+        out = _wait_tensor(funcol.all_to_all_single(out, None, None, group=group))
         out = out.flatten(0, 1).permute(1, 2, 0, 3).contiguous()
 
         if return_lse:
             lse = lse.reshape(B, world_size, S_LOCAL, H_LOCAL).permute(1, 3, 0, 2).contiguous()
-            lse = funcol.all_to_all_single(lse, None, None, group=group).wait()
+            lse = _wait_tensor(funcol.all_to_all_single(lse, None, None, group=group))
             lse = lse.flatten(0, 1).permute(1, 2, 0).contiguous()
         else:
             lse = None
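
Aside: the _wait_tensor helper introduced in attention_dispatch.py only calls .wait() when the functional collective actually handed back an AsyncCollectiveTensor, which is what keeps fullgraph tracing working (FakeTensor has no wait method). A quick illustrative sketch of that pattern; the single-process gloo setup is an assumption so the snippet runs anywhere, not part of the commit:

import os

import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol


def _wait_tensor(tensor):
    # Resolve the async wrapper if present; pass plain tensors through untouched.
    if isinstance(tensor, funcol.AsyncCollectiveTensor):
        tensor = tensor.wait()
    return tensor


if __name__ == "__main__":
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group("gloo", rank=0, world_size=1)
    x = torch.ones(4)
    gathered = funcol.all_gather_tensor(x, 0, group=dist.group.WORLD)
    print(type(gathered).__name__)  # usually AsyncCollectiveTensor in eager mode
    print(_wait_tensor(gathered))   # resolved tensor; shape (4,) with world_size=1
    dist.destroy_process_group()
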
