Skip to content

Commit b152915

Browse files
committed
added unit tests for bootstrap_environment remote
1 parent 0dea502 commit b152915

File tree

1 file changed

+249
-0
lines changed

1 file changed

+249
-0
lines changed

tests/unit/sagemaker/remote_function/test_job.py

Lines changed: 249 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,11 @@
4949
_prepare_dependencies_and_pre_execution_scripts,
5050
)
5151

52+
from sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment import (
53+
set_env,
54+
safe_serialize,
55+
)
56+
5257

5358
REGION = "us-west-2"
5459
TRAINING_JOB_ARN = "training-job-arn"
@@ -68,6 +73,87 @@
6873
EXPECTED_OUTPUT_URI = S3_URI + "/output"
6974
EXPECTED_DEPENDENCIES_URI = S3_URI + "/additional_dependencies/requirements.txt"
7075

76+
# flake8: noqa
77+
EXPECTED_ENV_SINGLE_NODE_CPU = """
78+
export SM_MODEL_DIR='/opt/ml/model'
79+
export SM_INPUT_DIR='/opt/ml/input'
80+
export SM_INPUT_DATA_DIR='/opt/ml/input/data'
81+
export SM_INPUT_CONFIG_DIR='/opt/ml/input/config'
82+
export SM_OUTPUT_DIR='/opt/ml/output'
83+
export SM_OUTPUT_FAILURE='/opt/ml/output/failure'
84+
export SM_OUTPUT_DATA_DIR='/opt/ml/output/data'
85+
export SM_MASTER_ADDR='algo-1'
86+
export SM_MASTER_PORT='7777'
87+
export SM_CURRENT_HOST='algo-1'
88+
export SM_CURRENT_INSTANCE_TYPE='ml.t3.xlarge'
89+
export SM_HOSTS='["algo-1"]'
90+
export SM_NETWORK_INTERFACE_NAME='eth0'
91+
export SM_HOST_COUNT='1'
92+
export SM_CURRENT_HOST_RANK='0'
93+
export SM_NUM_CPUS='4'
94+
export SM_NUM_GPUS='0'
95+
export SM_NUM_NEURONS='0'
96+
export SM_RESOURCE_CONFIG='{"current_host": "algo-1", "hosts": ["algo-1"], "current_group_name": "homogeneousCluster", "current_instance_type": "ml.t3.xlarge", "instance_groups": [{"instance_group_name": "homogeneousCluster", "instance_type": "ml.t3.xlarge", "hosts": ["algo-1"]}], "network_interface_name": "eth0"}'
97+
export SM_NPROC_PER_NODE='4'
98+
export SM_TRAINING_ENV='{"current_host": "algo-1", "current_instance_type": "ml.t3.xlarge", "hosts": ["algo-1"], "host_count": 1, "nproc_per_node": 4, "master_addr": "algo-1", "master_port": 7777, "input_config_dir": "/opt/ml/input/config", "input_data_dir": "/opt/ml/input/data", "input_dir": "/opt/ml/input", "job_name": "test-job", "model_dir": "/opt/ml/model", "network_interface_name": "eth0", "num_cpus": 4, "num_gpus": 0, "num_neurons": 0, "output_data_dir": "/opt/ml/output/data", "resource_config": {"current_host": "algo-1", "hosts": ["algo-1"], "current_group_name": "homogeneousCluster", "current_instance_type": "ml.t3.xlarge", "instance_groups": [{"instance_group_name": "homogeneousCluster", "instance_type": "ml.t3.xlarge", "hosts": ["algo-1"]}], "network_interface_name": "eth0"}}'
99+
export NCCL_SOCKET_IFNAME='eth0'
100+
export NCCL_PROTO='simple'
101+
"""
102+
103+
# flake8: noqa
104+
EXPECTED_ENV_SINGLE_NODE_MULTI_GPUS = """
105+
export SM_MODEL_DIR='/opt/ml/model'
106+
export SM_INPUT_DIR='/opt/ml/input'
107+
export SM_INPUT_DATA_DIR='/opt/ml/input/data'
108+
export SM_INPUT_CONFIG_DIR='/opt/ml/input/config'
109+
export SM_OUTPUT_DIR='/opt/ml/output'
110+
export SM_OUTPUT_FAILURE='/opt/ml/output/failure'
111+
export SM_OUTPUT_DATA_DIR='/opt/ml/output/data'
112+
export SM_MASTER_ADDR='algo-1'
113+
export SM_MASTER_PORT='7777'
114+
export SM_CURRENT_HOST='algo-1'
115+
export SM_CURRENT_INSTANCE_TYPE='ml.g5.12xlarge'
116+
export SM_HOSTS='["algo-1"]'
117+
export SM_NETWORK_INTERFACE_NAME='eth0'
118+
export SM_HOST_COUNT='1'
119+
export SM_CURRENT_HOST_RANK='0'
120+
export SM_NUM_CPUS='48'
121+
export SM_NUM_GPUS='4'
122+
export SM_NUM_NEURONS='0'
123+
export SM_RESOURCE_CONFIG='{"current_host": "algo-1", "hosts": ["algo-1"], "current_group_name": "homogeneousCluster", "current_instance_type": "ml.g5.12xlarge", "instance_groups": [{"instance_group_name": "homogeneousCluster", "instance_type": "ml.g5.12xlarge", "hosts": ["algo-1"]}], "network_interface_name": "eth0"}'
124+
export SM_NPROC_PER_NODE='4'
125+
export SM_TRAINING_ENV='{"current_host": "algo-1", "current_instance_type": "ml.g5.12xlarge", "hosts": ["algo-1"], "host_count": 1, "nproc_per_node": 4, "master_addr": "algo-1", "master_port": 7777, "input_config_dir": "/opt/ml/input/config", "input_data_dir": "/opt/ml/input/data", "input_dir": "/opt/ml/input", "job_name": "test-job", "model_dir": "/opt/ml/model", "network_interface_name": "eth0", "num_cpus": 48, "num_gpus": 4, "num_neurons": 0, "output_data_dir": "/opt/ml/output/data", "resource_config": {"current_host": "algo-1", "hosts": ["algo-1"], "current_group_name": "homogeneousCluster", "current_instance_type": "ml.g5.12xlarge", "instance_groups": [{"instance_group_name": "homogeneousCluster", "instance_type": "ml.g5.12xlarge", "hosts": ["algo-1"]}], "network_interface_name": "eth0"}}'
126+
export NCCL_SOCKET_IFNAME='eth0'
127+
export NCCL_PROTO='simple'
128+
"""
129+
130+
# flake8: noqa
131+
EXPECTED_ENV_MULTI_NODE_MULTI_GPUS = """
132+
export SM_MODEL_DIR='/opt/ml/model'
133+
export SM_INPUT_DIR='/opt/ml/input'
134+
export SM_INPUT_DATA_DIR='/opt/ml/input/data'
135+
export SM_INPUT_CONFIG_DIR='/opt/ml/input/config'
136+
export SM_OUTPUT_DIR='/opt/ml/output'
137+
export SM_OUTPUT_FAILURE='/opt/ml/output/failure'
138+
export SM_OUTPUT_DATA_DIR='/opt/ml/output/data'
139+
export SM_MASTER_ADDR='algo-1'
140+
export SM_MASTER_PORT='7777'
141+
export SM_CURRENT_HOST='algo-1'
142+
export SM_CURRENT_INSTANCE_TYPE='ml.g5.2xlarge'
143+
export SM_HOSTS='["algo-1", "algo-2", "algo-3", "algo-4"]'
144+
export SM_NETWORK_INTERFACE_NAME='eth0'
145+
export SM_HOST_COUNT='4'
146+
export SM_CURRENT_HOST_RANK='0'
147+
export SM_NUM_CPUS='8'
148+
export SM_NUM_GPUS='1'
149+
export SM_NUM_NEURONS='0'
150+
export SM_RESOURCE_CONFIG='{"current_host": "algo-1", "hosts": ["algo-1", "algo-2", "algo-3", "algo-4"], "current_group_name": "homogeneousCluster", "current_instance_type": "ml.g5.2xlarge", "instance_groups": [{"instance_group_name": "homogeneousCluster", "instance_type": "ml.g5.2xlarge", "hosts": ["algo-4", "algo-2", "algo-1", "algo-3"]}], "network_interface_name": "eth0"}'
151+
export SM_NPROC_PER_NODE='1'
152+
export SM_TRAINING_ENV='{"current_host": "algo-1", "current_instance_type": "ml.g5.2xlarge", "hosts": ["algo-1", "algo-2", "algo-3", "algo-4"], "host_count": 4, "nproc_per_node": 1, "master_addr": "algo-1", "master_port": 7777, "input_config_dir": "/opt/ml/input/config", "input_data_dir": "/opt/ml/input/data", "input_dir": "/opt/ml/input", "job_name": "test-job", "model_dir": "/opt/ml/model", "network_interface_name": "eth0", "num_cpus": 8, "num_gpus": 1, "num_neurons": 0, "output_data_dir": "/opt/ml/output/data", "resource_config": {"current_host": "algo-1", "hosts": ["algo-1", "algo-2", "algo-3", "algo-4"], "current_group_name": "homogeneousCluster", "current_instance_type": "ml.g5.2xlarge", "instance_groups": [{"instance_group_name": "homogeneousCluster", "instance_type": "ml.g5.2xlarge", "hosts": ["algo-4", "algo-2", "algo-1", "algo-3"]}], "network_interface_name": "eth0"}}'
153+
export NCCL_SOCKET_IFNAME='eth0'
154+
export NCCL_PROTO='simple'
155+
"""
156+
71157
DESCRIBE_TRAINING_JOB_RESPONSE = {
72158
"TrainingJobArn": TRAINING_JOB_ARN,
73159
"TrainingJobStatus": "{}",
@@ -79,6 +165,8 @@
79165
"OutputDataConfig": {"S3OutputPath": "s3://sagemaker-123/image_uri/output"},
80166
}
81167

168+
OUTPUT_FILE = os.path.join(os.path.dirname(__file__), "sm_training.env")
169+
82170
TEST_JOB_NAME = "my-job-name"
83171
TEST_PIPELINE_NAME = "my-pipeline"
84172
TEST_EXP_NAME = "my-exp-name"
@@ -1866,3 +1954,164 @@ def test_start_with_torchrun_multi_node(
18661954
EnableManagedSpotTraining=False,
18671955
Environment={"AWS_DEFAULT_REGION": "us-west-2", "REMOTE_FUNCTION_SECRET_KEY": HMAC_KEY},
18681956
)
1957+
1958+
1959+
@patch(
1960+
"sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.num_cpus",
1961+
return_value=4,
1962+
)
1963+
@patch(
1964+
"sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.num_gpus",
1965+
return_value=0,
1966+
)
1967+
@patch(
1968+
"sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.num_neurons",
1969+
return_value=0,
1970+
)
1971+
@patch(
1972+
"sagemaker.modules.train.container_drivers.scripts.environment.safe_serialize",
1973+
side_effect=safe_serialize,
1974+
)
1975+
def test_set_env_single_node_cpu(
1976+
mock_safe_serialize, mock_num_cpus, mock_num_gpus, mock_num_neurons
1977+
):
1978+
with patch.dict(os.environ, {"TRAINING_JOB_NAME": "test-job"}):
1979+
set_env(
1980+
resource_config=dict(
1981+
current_host="algo-1",
1982+
hosts=["algo-1"],
1983+
current_group_name="homogeneousCluster",
1984+
current_instance_type="ml.t3.xlarge",
1985+
instance_groups=[
1986+
dict(
1987+
instance_group_name="homogeneousCluster",
1988+
instance_type="ml.t3.xlarge",
1989+
hosts=["algo-1"],
1990+
)
1991+
],
1992+
network_interface_name="eth0",
1993+
),
1994+
output_file=OUTPUT_FILE,
1995+
)
1996+
1997+
mock_num_cpus.assert_called_once()
1998+
mock_num_gpus.assert_called_once()
1999+
mock_num_neurons.assert_called_once()
2000+
2001+
with open(OUTPUT_FILE, "r") as f:
2002+
env_file = f.read().strip()
2003+
expected_env = _remove_extra_lines(EXPECTED_ENV_SINGLE_NODE_CPU)
2004+
env_file = _remove_extra_lines(env_file)
2005+
2006+
assert env_file == expected_env
2007+
os.remove(OUTPUT_FILE)
2008+
assert not os.path.exists(OUTPUT_FILE)
2009+
2010+
2011+
@patch(
2012+
"sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.num_cpus",
2013+
return_value=48,
2014+
)
2015+
@patch(
2016+
"sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.num_gpus",
2017+
return_value=4,
2018+
)
2019+
@patch(
2020+
"sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.num_neurons",
2021+
return_value=0,
2022+
)
2023+
@patch(
2024+
"sagemaker.modules.train.container_drivers.scripts.environment.safe_serialize",
2025+
side_effect=safe_serialize,
2026+
)
2027+
def test_set_env_single_node_multi_gpu(
2028+
mock_safe_serialize, mock_num_cpus, mock_num_gpus, mock_num_neurons
2029+
):
2030+
with patch.dict(os.environ, {"TRAINING_JOB_NAME": "test-job"}):
2031+
set_env(
2032+
resource_config=dict(
2033+
current_host="algo-1",
2034+
hosts=["algo-1"],
2035+
current_group_name="homogeneousCluster",
2036+
current_instance_type="ml.g5.12xlarge",
2037+
instance_groups=[
2038+
dict(
2039+
instance_group_name="homogeneousCluster",
2040+
instance_type="ml.g5.12xlarge",
2041+
hosts=["algo-1"],
2042+
)
2043+
],
2044+
network_interface_name="eth0",
2045+
),
2046+
output_file=OUTPUT_FILE,
2047+
)
2048+
2049+
mock_num_cpus.assert_called_once()
2050+
mock_num_gpus.assert_called_once()
2051+
mock_num_neurons.assert_called_once()
2052+
2053+
with open(OUTPUT_FILE, "r") as f:
2054+
env_file = f.read().strip()
2055+
expected_env = _remove_extra_lines(EXPECTED_ENV_SINGLE_NODE_MULTI_GPUS)
2056+
env_file = _remove_extra_lines(env_file)
2057+
2058+
assert env_file == expected_env
2059+
os.remove(OUTPUT_FILE)
2060+
assert not os.path.exists(OUTPUT_FILE)
2061+
2062+
2063+
@patch(
2064+
"sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.num_cpus",
2065+
return_value=8,
2066+
)
2067+
@patch(
2068+
"sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.num_gpus",
2069+
return_value=1,
2070+
)
2071+
@patch(
2072+
"sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.num_neurons",
2073+
return_value=0,
2074+
)
2075+
@patch(
2076+
"sagemaker.modules.train.container_drivers.scripts.environment.safe_serialize",
2077+
side_effect=safe_serialize,
2078+
)
2079+
def test_set_env_multi_node_multi_gpu(
2080+
mock_safe_serialize, mock_num_cpus, mock_num_gpus, mock_num_neurons
2081+
):
2082+
with patch.dict(os.environ, {"TRAINING_JOB_NAME": "test-job"}):
2083+
set_env(
2084+
resource_config=dict(
2085+
current_host="algo-1",
2086+
hosts=["algo-1", "algo-2", "algo-3", "algo-4"],
2087+
current_group_name="homogeneousCluster",
2088+
current_instance_type="ml.g5.2xlarge",
2089+
instance_groups=[
2090+
dict(
2091+
instance_group_name="homogeneousCluster",
2092+
instance_type="ml.g5.2xlarge",
2093+
hosts=["algo-4", "algo-2", "algo-1", "algo-3"],
2094+
)
2095+
],
2096+
network_interface_name="eth0",
2097+
),
2098+
output_file=OUTPUT_FILE,
2099+
)
2100+
2101+
mock_num_cpus.assert_called_once()
2102+
mock_num_gpus.assert_called_once()
2103+
mock_num_neurons.assert_called_once()
2104+
2105+
with open(OUTPUT_FILE, "r") as f:
2106+
env_file = f.read().strip()
2107+
expected_env = _remove_extra_lines(EXPECTED_ENV_MULTI_NODE_MULTI_GPUS)
2108+
env_file = _remove_extra_lines(env_file)
2109+
2110+
assert env_file == expected_env
2111+
os.remove(OUTPUT_FILE)
2112+
assert not os.path.exists(OUTPUT_FILE)
2113+
2114+
2115+
def _remove_extra_lines(string):
2116+
"""Removes extra blank lines from a string."""
2117+
return "\n".join([line for line in string.splitlines() if line.strip()])

0 commit comments

Comments
 (0)