@@ -53,7 +53,7 @@ def _distributed_worker(rank, fn, world_size, kwargs):
 def distributed_launcher(request):
     n_gpus = getattr(request, "param", None)
     if not torch.cuda.is_available():
-        pytest.skip("CUDA required for distributed GPU test")
+        pytest.xfail("CUDA required for distributed GPU test")
     if torch.cuda.device_count() < n_gpus:
         pytest.skip(f"requires up to {n_gpus} CUDA devices, found {torch.cuda.device_count()}")
 
@@ -82,8 +82,7 @@ def launch(fn, **kwargs):
 
 @pytest.mark.parametrize("n_expts_shard, n_expts_tot", [(8, 512), (16, 64)])
 @pytest.mark.parametrize("affinity_mode", ["uniform", "random"])
-def test_make_expt_assignment(n_expts_shard, n_expts_tot, affinity_mode):
-    device = "cuda"
+def test_make_expt_assignment(n_expts_shard, n_expts_tot, affinity_mode, device):
     expt_dict = _make_expt_dict_for_mode(n_expts_shard, n_expts_tot, affinity_mode)
     expt_assignment = make_expt_assignment(n_expts_shard, n_expts_tot, expt_dict, device)
     # mask correctness & uniqueness: each expert set exactly once, and on the right shard