@@ -472,7 +472,7 @@ def test__native_dist_model_warning_index_less_localrank(local_rank, world_size):
     dist.destroy_process_group()


-def _test_dist_spawn_fn(local_rank, backend, world_size, device):
+def _test_dist_spawn_fn(local_rank, backend, world_size, device, **kwargs):
     from ignite.distributed.utils import _model

     assert dist.is_available() and dist.is_initialized()
@@ -484,12 +484,22 @@ def _test_dist_spawn_fn(local_rank, backend, world_size, device):
     assert _model.get_world_size() == world_size
     assert _model.device().type == torch.device(device).type

+    if "master_addr" in kwargs:
+        assert os.environ["MASTER_ADDR"] == kwargs["master_addr"]
+    if "master_port" in kwargs:
+        assert os.environ["MASTER_PORT"] == str(kwargs["master_port"])
+

 def _test__native_dist_model_spawn(backend, num_workers_per_machine, device, init_method=None, **spawn_kwargs):
+    kwargs_dict = {}
+    for key in ["master_addr", "master_port"]:
+        if key in spawn_kwargs:
+            kwargs_dict[key] = spawn_kwargs[key]
+
     _NativeDistModel.spawn(
         _test_dist_spawn_fn,
         args=(backend, num_workers_per_machine, device),
-        kwargs_dict={},
+        kwargs_dict=kwargs_dict,
         backend=backend,
         nproc_per_node=num_workers_per_machine,
         init_method=init_method,
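The point of threading kwargs_dict through here is that _NativeDistModel.spawn hands those entries to every worker as keyword arguments, so the spawned _test_dist_spawn_fn can check the MASTER_ADDR/MASTER_PORT environment set up by the parent. A minimal sketch of that calling convention, assuming workers are invoked as fn(local_rank, *args, **kwargs_dict); demo_fn and demo_spawn are illustrative names, not ignite internals:

import os


def demo_fn(local_rank, backend, world_size, **kwargs):
    # Worker side: compare the process environment against the values
    # the parent forwarded through kwargs_dict.
    if "master_addr" in kwargs:
        assert os.environ.get("MASTER_ADDR") == kwargs["master_addr"]
    if "master_port" in kwargs:
        assert os.environ.get("MASTER_PORT") == str(kwargs["master_port"])


def demo_spawn(fn, args, kwargs_dict, nproc_per_node):
    # Sequential stand-in for process creation; only the assumed
    # calling convention fn(local_rank, *args, **kwargs_dict) matters.
    for local_rank in range(nproc_per_node):
        fn(local_rank, *args, **kwargs_dict)


os.environ["MASTER_ADDR"], os.environ["MASTER_PORT"] = "0.0.0.0", "2345"
demo_spawn(
    demo_fn,
    args=("gloo", 4),
    kwargs_dict={"master_addr": "0.0.0.0", "master_port": 2345},
    nproc_per_node=4,
)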
@@ -499,31 +509,56 @@ def _test__native_dist_model_spawn(backend, num_workers_per_machine, device, init_method=None, **spawn_kwargs):

 @pytest.mark.distributed
 @pytest.mark.skipif("WORLD_SIZE" in os.environ, reason="Skip if launched as multiproc")
-@pytest.mark.parametrize("init_method", [None, "env://", "tcp://0.0.0.0:22334", "FILE"])
+@pytest.mark.parametrize("init_method", [None, "CUSTOM_ADDR_PORT", "env://", "tcp://0.0.0.0:22334", "FILE"])
 def test__native_dist_model_spawn_gloo(init_method, dirname):
+    spawn_kwargs = {}
+
     if init_method == "FILE":
         init_method = f"file://{dirname}/shared"
+    elif init_method == "CUSTOM_ADDR_PORT":
+        init_method = None
+        spawn_kwargs["master_addr"] = "0.0.0.0"
+        spawn_kwargs["master_port"] = 2345

     nproc = torch.cuda.device_count() if torch.cuda.is_available() else 4
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    _test__native_dist_model_spawn("gloo", num_workers_per_machine=nproc, device=device, init_method=init_method)
+    _test__native_dist_model_spawn(
+        "gloo", num_workers_per_machine=nproc, device=device, init_method=init_method, **spawn_kwargs
+    )
     if device.type == "cpu":
+        spawn_kwargs["start_method"] = "fork"
         _test__native_dist_model_spawn(
-            "gloo", num_workers_per_machine=nproc, device=device, start_method="fork", init_method=init_method
+            "gloo", num_workers_per_machine=nproc, device=device, init_method=init_method, **spawn_kwargs
         )

+    if init_method not in [None, "env://"]:
+        with pytest.raises(ValueError, match=r"master_addr should be None if init_method is provided"):
+            _test__native_dist_model_spawn(
+                "gloo", num_workers_per_machine=nproc, device=device, init_method=init_method, master_addr="abc"
+            )
+        with pytest.raises(ValueError, match=r"master_port should be None if init_method is provided"):
+            _test__native_dist_model_spawn(
+                "gloo", num_workers_per_machine=nproc, device=device, init_method=init_method, master_port=123
+            )
+

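The two pytest.raises blocks pin down the expected mutual exclusivity: explicit master_addr/master_port only make sense for env-variable rendezvous (init_method of None or "env://"), so combining them with a file or TCP init_method must raise. A hedged sketch of validation consistent with the matched error messages, not necessarily ignite's actual code:

def check_spawn_params(init_method=None, master_addr=None, master_port=None):
    # File/TCP rendezvous already carries its own address, so explicit
    # master_addr/master_port would be contradictory there.
    if init_method not in (None, "env://"):
        if master_addr is not None:
            raise ValueError("master_addr should be None if init_method is provided")
        if master_port is not None:
            raise ValueError("master_port should be None if init_method is provided")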
 @pytest.mark.distributed
 @pytest.mark.skipif("WORLD_SIZE" in os.environ, reason="Skip if launched as multiproc")
 @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
-@pytest.mark.parametrize("init_method", [None, "tcp://0.0.0.0:22334", "FILE"])
+@pytest.mark.parametrize("init_method", [None, "CUSTOM_ADDR_PORT", "tcp://0.0.0.0:22334", "FILE"])
 def test__native_dist_model_spawn_nccl(init_method, dirname):
+    spawn_kwargs = {}
+
     if init_method == "FILE":
         init_method = f"file://{dirname}/shared"
+    elif init_method == "CUSTOM_ADDR_PORT":
+        init_method = None
+        spawn_kwargs["master_addr"] = "0.0.0.0"
+        spawn_kwargs["master_port"] = 2345

-    num_workers_per_machine = torch.cuda.device_count()
+    nproc = torch.cuda.device_count()
     _test__native_dist_model_spawn(
-        "nccl", num_workers_per_machine=num_workers_per_machine, device="cuda", init_method=init_method
+        "nccl", num_workers_per_machine=nproc, device="cuda", init_method=init_method, **spawn_kwargs
     )

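At the public API level these tests correspond to passing the new keywords through spawn. A hypothetical user-facing call, assuming idist.spawn forwards master_addr/master_port the same way the tests pass them via **spawn_kwargs:

import ignite.distributed as idist


def training(local_rank):
    # Runs in each spawned process once the "gloo" group is initialized.
    print(idist.get_rank(), idist.get_world_size())


# env:// rendezvous with an explicit address/port; combining these
# kwargs with e.g. init_method="tcp://..." should raise the
# ValueErrors exercised above.
idist.spawn(
    "gloo",
    training,
    args=(),
    nproc_per_node=4,
    master_addr="0.0.0.0",
    master_port=2345,
)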