|
63 | 63 | "SM_ENTRY_SCRIPT": "/opt/ml/input/data/code/script.py", |
64 | 64 | }, |
65 | 65 | ) |
66 | | -@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.write_env_vars_to_file") |
| 66 | +@patch( |
| 67 | + "sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.write_env_vars_to_file" |
| 68 | +) |
67 | 69 | @patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.start_sshd_daemon") |
68 | | -@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.bootstrap_master_node") |
69 | | -@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.bootstrap_worker_node") |
70 | | -@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.hyperparameters_to_cli_args") |
71 | | -@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.get_mpirun_command") |
| 70 | +@patch( |
| 71 | + "sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.bootstrap_master_node" |
| 72 | +) |
| 73 | +@patch( |
| 74 | + "sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.bootstrap_worker_node" |
| 75 | +) |
| 76 | +@patch( |
| 77 | + "sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.hyperparameters_to_cli_args" |
| 78 | +) |
| 79 | +@patch( |
| 80 | + "sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.get_mpirun_command" |
| 81 | +) |
72 | 82 | @patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.execute_commands") |
73 | 83 | def test_mpi_driver_worker( |
74 | 84 | mock_execute_commands, |
@@ -104,15 +114,27 @@ def test_mpi_driver_worker( |
104 | 114 | "SM_ENTRY_SCRIPT": "script.py", |
105 | 115 | }, |
106 | 116 | ) |
107 | | -@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.write_env_vars_to_file") |
| 117 | +@patch( |
| 118 | + "sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.write_env_vars_to_file" |
| 119 | +) |
108 | 120 | @patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.start_sshd_daemon") |
109 | | -@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.bootstrap_master_node") |
110 | | -@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.bootstrap_worker_node") |
| 121 | +@patch( |
| 122 | + "sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.bootstrap_master_node" |
| 123 | +) |
| 124 | +@patch( |
| 125 | + "sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.bootstrap_worker_node" |
| 126 | +) |
111 | 127 | @patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.get_process_count") |
112 | | -@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.hyperparameters_to_cli_args") |
113 | | -@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.get_mpirun_command") |
| 128 | +@patch( |
| 129 | + "sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.hyperparameters_to_cli_args" |
| 130 | +) |
| 131 | +@patch( |
| 132 | + "sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.get_mpirun_command" |
| 133 | +) |
114 | 134 | @patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.execute_commands") |
115 | | -@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.write_status_file_to_workers") |
| 135 | +@patch( |
| 136 | + "sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.write_status_file_to_workers" |
| 137 | +) |
116 | 138 | def test_mpi_driver_master( |
117 | 139 | mock_write_status_file_to_workers, |
118 | 140 | mock_execute_commands, |
|
0 commit comments