@@ -13,15 +13,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
+
 import pytest
 import torch
+import torch.multiprocessing as mp
 
+from diffusers.models._modeling_parallel import ContextParallelConfig
 from diffusers.models.attention import AttentionModuleMixin
 from diffusers.models.attention_processor import (
     AttnProcessor,
 )
 
-from ...testing_utils import is_attention, torch_device
+from ...testing_utils import is_attention, is_context_parallel, require_torch_multi_accelerator, torch_device
 
 
 @is_attention
@@ -85,9 +89,9 @@ def test_fuse_unfuse_qkv_projections(self):
         output_after_fusion = output_after_fusion.to_tuple()[0]
 
         # Verify outputs match
-        assert torch.allclose(
-            output_before_fusion, output_after_fusion, atol=self.base_precision
-        ), "Output should not change after fusing projections"
+        assert torch.allclose(output_before_fusion, output_after_fusion, atol=self.base_precision), (
+            "Output should not change after fusing projections"
+        )
 
         # Unfuse projections
         model.unfuse_qkv_projections()
@@ -106,9 +110,9 @@ def test_fuse_unfuse_qkv_projections(self):
         output_after_unfusion = output_after_unfusion.to_tuple()[0]
 
         # Verify outputs still match
-        assert torch.allclose(
-            output_before_fusion, output_after_unfusion, atol=self.base_precision
-        ), "Output should match original after unfusing projections"
+        assert torch.allclose(output_before_fusion, output_after_unfusion, atol=self.base_precision), (
+            "Output should match original after unfusing projections"
+        )
 
     def test_get_set_processor(self):
         init_dict = self.get_init_dict()
@@ -177,3 +181,83 @@ def test_attention_processor_count_mismatch_raises_error(self):
             model.set_attn_processor(wrong_processors)
 
         assert "number of processors" in str(exc_info.value).lower(), "Error should mention processor count mismatch"
+
+
+def _context_parallel_worker(rank, world_size, model_class, init_dict, cp_dict, inputs_dict, result_queue):
+    try:
+        # Set up the distributed environment for this rank
+        os.environ["MASTER_ADDR"] = "localhost"
+        os.environ["MASTER_PORT"] = "12355"
+
+        torch.distributed.init_process_group(
+            backend="nccl",
+            init_method="env://",
+            world_size=world_size,
+            rank=rank,
+        )
+        torch.cuda.set_device(rank)
+        device = torch.device(f"cuda:{rank}")
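+
+        # Instantiate the model on this rank and put it in eval mode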
+        model = model_class(**init_dict)
+        model.to(device)
+        model.eval()
+
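+        # Move tensor inputs to this rank's device; pass everything else through unchanged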
+        inputs_on_device = {}
+        for key, value in inputs_dict.items():
+            if isinstance(value, torch.Tensor):
+                inputs_on_device[key] = value.to(device)
+            else:
+                inputs_on_device[key] = value
+
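+        # Enable context parallelism; sharding follows the model's _cp_plan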
+        cp_config = ContextParallelConfig(**cp_dict)
+        model.enable_parallelism(config=cp_config)
+
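+        # Run a forward pass under no_grad; unwrap ModelOutput-style dicts to the first tensor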
+        with torch.no_grad():
+            output = model(**inputs_on_device)
+            if isinstance(output, dict):
+                output = output.to_tuple()[0]
+
+        if rank == 0:
+            result_queue.put(("success", output.shape))
+
+    except Exception as e:
+        if rank == 0:
+            result_queue.put(("error", str(e)))
+    finally:
+        if torch.distributed.is_initialized():
+            torch.distributed.destroy_process_group()
+
+
+@is_context_parallel
+@require_torch_multi_accelerator
+class ContextParallelTesterMixin:
+    base_precision = 1e-3
+
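+    # Exercise both Ulysses and ring attention variants of context parallelism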
+    @pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
+    def test_context_parallel_inference(self, cp_type):
+        if not torch.distributed.is_available():
+            pytest.skip("torch.distributed is not available.")
+
+        if not torch.cuda.is_available() or torch.cuda.device_count() < 2:
+            pytest.skip("Context parallel requires at least 2 CUDA devices.")
+
+        if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
+            pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")
+
+        world_size = 2
+        init_dict = self.get_init_dict()
+        inputs_dict = self.get_dummy_inputs()
+        cp_dict = {cp_type: world_size}
+
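+        # Use a spawn context: CUDA cannot be safely re-initialized in forked workers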
+        ctx = mp.get_context("spawn")
+        result_queue = ctx.Queue()
+
+        mp.spawn(
+            _context_parallel_worker,
+            args=(world_size, self.model_class, init_dict, cp_dict, inputs_dict, result_queue),
+            nprocs=world_size,
+            join=True,
+        )
+
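+        # Rank 0 posts either ("success", output_shape) or ("error", message)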
+        status, result = result_queue.get(timeout=60)
+        assert status == "success", f"Context parallel inference failed: {result}"