@@ -105,19 +105,41 @@ def apply_context_parallel(
             registry = HookRegistry.check_if_exists_or_initialize(m)
             registry.register_hook(hook, hook_name)
 
-    registry = HookRegistry.check_if_exists_or_initialize(module)
-    hook = ContextParallelModelHook(parallel_config)
-    registry.register_hook(hook, _CONTEXT_PARALLEL_MODEL_HOOK)
+    # HACK: we cannot use context managers or setattr or similar solutions in an overwritten forward
+    # diffusers hook method because Dynamo fails to trace it. Instead, we make use of module hooks
+    # available in pytorch to set the parallel context before/after the forward/backward pass.
+    # It is dirty, but fullgraph=True tracing works because of this and I haven't found a better solution yet.
+    # The previous/older implementation simply did this:
+    #     def new_forward(self, ...):
+    #         with _parallel_context(parallel_config):
+    #             return self.fn_ref.original_forward(*args, **kwargs)
+    # TODO: ask help from Pytorch team on how to improve this
+    @torch.compiler.disable
+    def forward_pre_hook(module, args):
+        module._diffusers_parallel_config_setter_context = _parallel_context(parallel_config)
+        module._diffusers_parallel_config_setter_context.__enter__()
 
+    @torch.compiler.disable
+    def forward_hook(module, args, output):
+        if module._diffusers_parallel_config_setter_context is not None:
+            module._diffusers_parallel_config_setter_context.__exit__(None, None, None)
+        module._diffusers_parallel_config_setter_context = None
 
-class ContextParallelModelHook(ModelHook):
-    def __init__(self, parallel_config: ParallelConfig) -> None:
-        super().__init__()
-        self.parallel_config = parallel_config
+    @torch.compiler.disable
+    def backward_pre_hook(module, grad_output):
+        module._diffusers_parallel_config_setter_context = _parallel_context(parallel_config)
+        module._diffusers_parallel_config_setter_context.__enter__()
 
-    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
-        with _parallel_context(self.parallel_config):
-            return self.fn_ref.original_forward(*args, **kwargs)
+    @torch.compiler.disable
+    def backward_hook(module, grad_output, grad_input):
+        if module._diffusers_parallel_config_setter_context is not None:
+            module._diffusers_parallel_config_setter_context.__exit__(None, None, None)
+        module._diffusers_parallel_config_setter_context = None
+
+    module.register_forward_pre_hook(forward_pre_hook)
+    module.register_forward_hook(forward_hook)
+    module.register_full_backward_pre_hook(backward_pre_hook)
+    module.register_full_backward_hook(backward_hook)
 
 
 class ContextParallelSplitHook(ModelHook):
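For reference, here is a minimal, self-contained sketch (not part of the commit) of the module-hook pattern the comment above describes, showing only the forward side. The helper name `install_parallel_context_hooks` and the toy `_parallel_context` are assumptions for illustration; the hooks are wrapped in `torch.compiler.disable` so they run outside the Dynamo-traced region while the module body can still be compiled with `fullgraph=True`.

```python
import contextlib
import torch

@contextlib.contextmanager
def _parallel_context(parallel_config):
    # Stand-in for the real context manager that activates `parallel_config`
    # for the duration of the wrapped call.
    yield

def install_parallel_context_hooks(module: torch.nn.Module, parallel_config) -> None:
    # Hypothetical helper mirroring the pattern above: enter the context in a
    # forward pre-hook and exit it in a forward hook, instead of overriding
    # forward with a `with` block that Dynamo would have to trace.
    @torch.compiler.disable
    def forward_pre_hook(mod, args):
        mod._parallel_ctx = _parallel_context(parallel_config)
        mod._parallel_ctx.__enter__()

    @torch.compiler.disable
    def forward_hook(mod, args, output):
        if mod._parallel_ctx is not None:
            mod._parallel_ctx.__exit__(None, None, None)
        mod._parallel_ctx = None

    module.register_forward_pre_hook(forward_pre_hook)
    module.register_forward_hook(forward_hook)
```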
@@ -234,13 +256,15 @@ def post_forward(self, module, output):
 
 class EquipartitionSharder:
     @classmethod
-    @torch.compiler.disable
     def shard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
         assert tensor.size()[dim] % mesh.size() == 0
-        return tensor.chunk(mesh.size(), dim=dim)[mesh.get_rank()]
+
+        # The following is not fullgraph compatible with Dynamo (fails in DeviceMesh.get_rank)
+        # return tensor.chunk(mesh.size(), dim=dim)[mesh.get_rank()]
+
+        return tensor.chunk(mesh.size(), dim=dim)[torch.distributed.get_rank(mesh.get_group())]
 
     @classmethod
-    @torch.compiler.disable
     def unshard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
         tensor = tensor.contiguous()
         tensor = funcol.all_gather_tensor(tensor, dim, group=mesh.get_group())
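A hedged usage sketch (not from the commit) of the same rank lookup and a shard/unshard round trip. The free functions `shard`/`unshard` here are standalone stand-ins for the class methods above, and the example assumes a process group initialized e.g. via torchrun:

```python
import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol
from torch.distributed.device_mesh import init_device_mesh

def shard(tensor: torch.Tensor, dim: int, mesh) -> torch.Tensor:
    # Same idea as the patched EquipartitionSharder.shard: resolve this rank
    # through the mesh's process group instead of DeviceMesh.get_rank(), which
    # the comment above notes is not fullgraph compatible with Dynamo.
    assert tensor.size()[dim] % mesh.size() == 0
    return tensor.chunk(mesh.size(), dim=dim)[dist.get_rank(mesh.get_group())]

def unshard(tensor: torch.Tensor, dim: int, mesh) -> torch.Tensor:
    # Mirrors EquipartitionSharder.unshard: all-gather the local chunks back
    # along `dim`.
    tensor = tensor.contiguous()
    return funcol.all_gather_tensor(tensor, dim, group=mesh.get_group())

# Example (run under torchrun with torch.distributed initialized):
# mesh = init_device_mesh("cuda", (dist.get_world_size(),))
# local = shard(torch.randn(8, 16, device="cuda"), dim=1, mesh=mesh)
# full = unshard(local, dim=1, mesh=mesh)
```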