|
63 | 63 | "SM_ENTRY_SCRIPT": "/opt/ml/input/data/code/script.py",
|
64 | 64 | },
|
65 | 65 | )
|
66 |
| -@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.write_env_vars_to_file") |
| 66 | +@patch( |
| 67 | + "sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.write_env_vars_to_file" |
| 68 | +) |
67 | 69 | @patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.start_sshd_daemon")
|
68 |
| -@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.bootstrap_master_node") |
69 |
| -@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.bootstrap_worker_node") |
70 |
| -@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.hyperparameters_to_cli_args") |
71 |
| -@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.get_mpirun_command") |
| 70 | +@patch( |
| 71 | + "sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.bootstrap_master_node" |
| 72 | +) |
| 73 | +@patch( |
| 74 | + "sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.bootstrap_worker_node" |
| 75 | +) |
| 76 | +@patch( |
| 77 | + "sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.hyperparameters_to_cli_args" |
| 78 | +) |
| 79 | +@patch( |
| 80 | + "sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.get_mpirun_command" |
| 81 | +) |
72 | 82 | @patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.execute_commands")
|
73 | 83 | def test_mpi_driver_worker(
|
74 | 84 | mock_execute_commands,
|
@@ -104,15 +114,27 @@ def test_mpi_driver_worker(
|
104 | 114 | "SM_ENTRY_SCRIPT": "script.py",
|
105 | 115 | },
|
106 | 116 | )
|
107 |
| -@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.write_env_vars_to_file") |
| 117 | +@patch( |
| 118 | + "sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.write_env_vars_to_file" |
| 119 | +) |
108 | 120 | @patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.start_sshd_daemon")
|
109 |
| -@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.bootstrap_master_node") |
110 |
| -@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.bootstrap_worker_node") |
| 121 | +@patch( |
| 122 | + "sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.bootstrap_master_node" |
| 123 | +) |
| 124 | +@patch( |
| 125 | + "sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.bootstrap_worker_node" |
| 126 | +) |
111 | 127 | @patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.get_process_count")
|
112 |
| -@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.hyperparameters_to_cli_args") |
113 |
| -@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.get_mpirun_command") |
| 128 | +@patch( |
| 129 | + "sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.hyperparameters_to_cli_args" |
| 130 | +) |
| 131 | +@patch( |
| 132 | + "sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.get_mpirun_command" |
| 133 | +) |
114 | 134 | @patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.execute_commands")
|
115 |
| -@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.write_status_file_to_workers") |
| 135 | +@patch( |
| 136 | + "sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.write_status_file_to_workers" |
| 137 | +) |
116 | 138 | def test_mpi_driver_master(
|
117 | 139 | mock_write_status_file_to_workers,
|
118 | 140 | mock_execute_commands,
|
|
0 commit comments