Skip to content

Commit 2c3c2e9

Browse files
committed
update paths
1 parent b562d69 commit 2c3c2e9

File tree

1 file changed

+33
-11
lines changed

1 file changed

+33
-11
lines changed

tests/unit/sagemaker/modules/train/container_drivers/test_mpi_driver.py

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -63,12 +63,22 @@
6363
"SM_ENTRY_SCRIPT": "/opt/ml/input/data/code/script.py",
6464
},
6565
)
66-
@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.write_env_vars_to_file")
66+
@patch(
67+
"sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.write_env_vars_to_file"
68+
)
6769
@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.start_sshd_daemon")
68-
@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.bootstrap_master_node")
69-
@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.bootstrap_worker_node")
70-
@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.hyperparameters_to_cli_args")
71-
@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.get_mpirun_command")
70+
@patch(
71+
"sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.bootstrap_master_node"
72+
)
73+
@patch(
74+
"sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.bootstrap_worker_node"
75+
)
76+
@patch(
77+
"sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.hyperparameters_to_cli_args"
78+
)
79+
@patch(
80+
"sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.get_mpirun_command"
81+
)
7282
@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.execute_commands")
7383
def test_mpi_driver_worker(
7484
mock_execute_commands,
@@ -104,15 +114,27 @@ def test_mpi_driver_worker(
104114
"SM_ENTRY_SCRIPT": "script.py",
105115
},
106116
)
107-
@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.write_env_vars_to_file")
117+
@patch(
118+
"sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.write_env_vars_to_file"
119+
)
108120
@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.start_sshd_daemon")
109-
@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.bootstrap_master_node")
110-
@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.bootstrap_worker_node")
121+
@patch(
122+
"sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.bootstrap_master_node"
123+
)
124+
@patch(
125+
"sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.bootstrap_worker_node"
126+
)
111127
@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.get_process_count")
112-
@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.hyperparameters_to_cli_args")
113-
@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.get_mpirun_command")
128+
@patch(
129+
"sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.hyperparameters_to_cli_args"
130+
)
131+
@patch(
132+
"sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.get_mpirun_command"
133+
)
114134
@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.execute_commands")
115-
@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.write_status_file_to_workers")
135+
@patch(
136+
"sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.write_status_file_to_workers"
137+
)
116138
def test_mpi_driver_master(
117139
mock_write_status_file_to_workers,
118140
mock_execute_commands,

0 commit comments

Comments
 (0)