
Commit 4fe2747

completed unit tests
1 parent 841af92 commit 4fe2747

4 files changed: +269 −7 lines changed

src/sagemaker/remote_function/client.py

Lines changed: 2 additions & 2 deletions

@@ -729,8 +729,8 @@ def __init__(
             raise ValueError("max_parallel_jobs must be greater than 0.")

         if instance_count > 1 and not (
-            (spark_config is not None and not use_torchrun)
-            or (spark_config is None and use_torchrun)
+            (spark_config is not None and not use_torchrun)
+            or (spark_config is None and use_torchrun)
         ):
             raise ValueError(
                 "Remote function do not support training on multi instances "

src/sagemaker/remote_function/job.py

Lines changed: 1 addition & 1 deletion

@@ -281,7 +281,7 @@ def __init__(
         spark_config: SparkConfig = None,
         use_spot_instances=False,
         max_wait_time_in_seconds=None,
-        use_torchrun=False,
+        use_torchrun: bool = False,
     ):
         """Initialize a _JobSettings instance which configures the remote job.

src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py

Lines changed: 9 additions & 4 deletions

@@ -496,9 +496,6 @@ def main(sys_args=None):
     exit_code = DEFAULT_FAILURE_CODE

     try:
-        with open(RESOURCE_CONFIG, "r") as f:
-            resource_config = json.load(f)
-
         args = _parse_args(sys_args)
         client_python_version = args.client_python_version
         client_sagemaker_pysdk_version = args.client_sagemaker_pysdk_version
@@ -535,7 +532,15 @@ def main(sys_args=None):
                 client_sagemaker_pysdk_version
             )

-        set_env(resource_config=resource_config)
+        if os.path.exists(RESOURCE_CONFIG):
+            try:
+                logger.info(f"Found {RESOURCE_CONFIG}")
+                with open(RESOURCE_CONFIG, "r") as f:
+                    resource_config = json.load(f)
+                set_env(resource_config=resource_config)
+            except (json.JSONDecodeError, FileNotFoundError) as e:
+                # Optionally, you might want to log this error
+                logger.info(f"Error processing {RESOURCE_CONFIG}: {str(e)}")

         exit_code = SUCCESS_EXIT_CODE
     except Exception as e:  # pylint: disable=broad-except
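The hunk above takes the resource-config read out of the mandatory path of main(): the file is now loaded, and set_env called, only when it exists, and JSON errors are logged instead of failing the bootstrap. A standalone sketch of the same guard, with stand-in names (load_optional_json and the logging setup are illustrative; the path shown is the usual SageMaker resource-config location and is only an example):

import json
import logging
import os

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def load_optional_json(path):
    """Return parsed JSON from `path`, or None if the file is absent or unreadable."""
    if not os.path.exists(path):
        return None
    try:
        logger.info("Found %s", path)
        with open(path, "r") as f:
            return json.load(f)
    except (json.JSONDecodeError, FileNotFoundError) as e:
        # Mirror the diff: log and continue instead of aborting the bootstrap.
        logger.info("Error processing %s: %s", path, e)
        return None


resource_config = load_optional_json("/opt/ml/input/config/resourceconfig.json")
if resource_config is not None:
    ...  # e.g. set_env(resource_config=resource_config) in the real script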

tests/unit/sagemaker/remote_function/test_job.py

Lines changed: 257 additions & 0 deletions

@@ -389,6 +389,7 @@ def test_start(
         s3_base_uri=f"{S3_URI}/{job.job_name}",
         s3_kms_key=None,
         sagemaker_session=session(),
+        use_torchrun=False,
     )

     mock_dependency_upload.assert_called_once_with(
@@ -670,6 +671,7 @@ def test_start_with_complete_job_settings(
         s3_base_uri=f"{S3_URI}/{job.job_name}",
         s3_kms_key=job_settings.s3_kms_key,
         sagemaker_session=session(),
+        use_torchrun=False,
     )

     mock_user_workspace_upload.assert_called_once_with(
@@ -840,6 +842,7 @@ def test_get_train_args_under_pipeline_context(
         s3_base_uri=s3_base_uri,
         s3_kms_key=job_settings.s3_kms_key,
         sagemaker_session=session(),
+        use_torchrun=False,
     )

     mock_user_workspace_upload.assert_called_once_with(
@@ -1014,6 +1017,7 @@ def test_start_with_spark(
         s3_base_uri=f"{S3_URI}/{job.job_name}",
         s3_kms_key=None,
         sagemaker_session=session(),
+        use_torchrun=False,
     )

     session().sagemaker_client.create_training_job.assert_called_once_with(
@@ -1601,3 +1605,256 @@ def test_extend_spark_config_to_request(
             }
         ],
     )
+
+
+@patch("sagemaker.experiments._run_context._RunContext.get_current_run", new=mock_get_current_run)
+@patch("secrets.token_hex", return_value=HMAC_KEY)
+@patch("sagemaker.remote_function.job._prepare_and_upload_workspace", return_value="some_s3_uri")
+@patch(
+    "sagemaker.remote_function.job._prepare_and_upload_runtime_scripts", return_value="some_s3_uri"
+)
+@patch("sagemaker.remote_function.job.RuntimeEnvironmentManager")
+@patch("sagemaker.remote_function.job.StoredFunction")
+@patch("sagemaker.remote_function.job.Session", return_value=mock_session())
+def test_start_with_torchrun_single_node(
+    session,
+    mock_stored_function,
+    mock_runtime_manager,
+    mock_script_upload,
+    mock_dependency_upload,
+    secret_token,
+):
+
+    job_settings = _JobSettings(
+        image_uri=IMAGE,
+        s3_root_uri=S3_URI,
+        role=ROLE_ARN,
+        include_local_workdir=True,
+        instance_type="ml.g5.12xlarge",
+        encrypt_inter_container_traffic=True,
+        use_torchrun=True,
+    )
+
+    job = _Job.start(job_settings, job_function, func_args=(1, 2), func_kwargs={"c": 3, "d": 4})
+
+    assert job.job_name.startswith("job-function")
+
+    mock_stored_function.assert_called_once_with(
+        sagemaker_session=session(),
+        s3_base_uri=f"{S3_URI}/{job.job_name}",
+        hmac_key=HMAC_KEY,
+        s3_kms_key=None,
+    )
+
+    mock_stored_function().save.assert_called_once_with(job_function, *(1, 2), **{"c": 3, "d": 4})
+
+    local_dependencies_path = mock_runtime_manager().snapshot()
+    mock_python_version = mock_runtime_manager()._current_python_version()
+    mock_sagemaker_pysdk_version = mock_runtime_manager()._current_sagemaker_pysdk_version()
+
+    mock_script_upload.assert_called_once_with(
+        spark_config=None,
+        s3_base_uri=f"{S3_URI}/{job.job_name}",
+        s3_kms_key=None,
+        sagemaker_session=session(),
+        use_torchrun=True,
+    )
+
+    mock_dependency_upload.assert_called_once_with(
+        local_dependencies_path=local_dependencies_path,
+        include_local_workdir=True,
+        pre_execution_commands=None,
+        pre_execution_script_local_path=None,
+        s3_base_uri=f"{S3_URI}/{job.job_name}",
+        s3_kms_key=None,
+        sagemaker_session=session(),
+        custom_file_filter=None,
+    )
+
+    session().sagemaker_client.create_training_job.assert_called_once_with(
+        TrainingJobName=job.job_name,
+        RoleArn=ROLE_ARN,
+        StoppingCondition={"MaxRuntimeInSeconds": 86400},
+        RetryStrategy={"MaximumRetryAttempts": 1},
+        InputDataConfig=[
+            dict(
+                ChannelName=RUNTIME_SCRIPTS_CHANNEL_NAME,
+                DataSource={
+                    "S3DataSource": {
+                        "S3Uri": mock_script_upload.return_value,
+                        "S3DataType": "S3Prefix",
+                    }
+                },
+            ),
+            dict(
+                ChannelName=REMOTE_FUNCTION_WORKSPACE,
+                DataSource={
+                    "S3DataSource": {
+                        "S3Uri": mock_dependency_upload.return_value,
+                        "S3DataType": "S3Prefix",
+                    }
+                },
+            ),
+        ],
+        OutputDataConfig={"S3OutputPath": f"{S3_URI}/{job.job_name}"},
+        AlgorithmSpecification=dict(
+            TrainingImage=IMAGE,
+            TrainingInputMode="File",
+            ContainerEntrypoint=[
+                "/bin/bash",
+                "/opt/ml/input/data/sagemaker_remote_function_bootstrap/job_driver.sh",
+            ],
+            ContainerArguments=[
+                "--s3_base_uri",
+                f"{S3_URI}/{job.job_name}",
+                "--region",
+                TEST_REGION,
+                "--client_python_version",
+                mock_python_version,
+                "--client_sagemaker_pysdk_version",
+                mock_sagemaker_pysdk_version,
+                "--dependency_settings",
+                '{"dependency_file": null}',
+                "--run_in_context",
+                '{"experiment_name": "my-exp-name", "run_name": "my-run-name"}',
+            ],
+        ),
+        ResourceConfig=dict(
+            VolumeSizeInGB=30,
+            InstanceCount=1,
+            InstanceType="ml.g5.12xlarge",
+            KeepAlivePeriodInSeconds=0,
+        ),
+        EnableNetworkIsolation=False,
+        EnableInterContainerTrafficEncryption=True,
+        EnableManagedSpotTraining=False,
+        Environment={"AWS_DEFAULT_REGION": "us-west-2", "REMOTE_FUNCTION_SECRET_KEY": HMAC_KEY},
+    )
+
+
+@patch("sagemaker.experiments._run_context._RunContext.get_current_run", new=mock_get_current_run)
+@patch("secrets.token_hex", return_value=HMAC_KEY)
+@patch("sagemaker.remote_function.job._prepare_and_upload_workspace", return_value="some_s3_uri")
+@patch(
+    "sagemaker.remote_function.job._prepare_and_upload_runtime_scripts", return_value="some_s3_uri"
+)
+@patch("sagemaker.remote_function.job.RuntimeEnvironmentManager")
+@patch("sagemaker.remote_function.job.StoredFunction")
+@patch("sagemaker.remote_function.job.Session", return_value=mock_session())
+def test_start_with_torchrun_multi_node(
+    session,
+    mock_stored_function,
+    mock_runtime_manager,
+    mock_script_upload,
+    mock_dependency_upload,
+    secret_token,
+):
+
+    job_settings = _JobSettings(
+        image_uri=IMAGE,
+        s3_root_uri=S3_URI,
+        role=ROLE_ARN,
+        include_local_workdir=True,
+        instance_count=2,
+        instance_type="ml.g5.2xlarge",
+        encrypt_inter_container_traffic=True,
+        use_torchrun=True,
+    )
+
+    job = _Job.start(job_settings, job_function, func_args=(1, 2), func_kwargs={"c": 3, "d": 4})
+
+    assert job.job_name.startswith("job-function")
+
+    mock_stored_function.assert_called_once_with(
+        sagemaker_session=session(),
+        s3_base_uri=f"{S3_URI}/{job.job_name}",
+        hmac_key=HMAC_KEY,
+        s3_kms_key=None,
+    )
+
+    mock_stored_function().save.assert_called_once_with(job_function, *(1, 2), **{"c": 3, "d": 4})
+
+    local_dependencies_path = mock_runtime_manager().snapshot()
+    mock_python_version = mock_runtime_manager()._current_python_version()
+    mock_sagemaker_pysdk_version = mock_runtime_manager()._current_sagemaker_pysdk_version()
+
+    mock_script_upload.assert_called_once_with(
+        spark_config=None,
+        s3_base_uri=f"{S3_URI}/{job.job_name}",
+        s3_kms_key=None,
+        sagemaker_session=session(),
+        use_torchrun=True,
+    )
+
+    mock_dependency_upload.assert_called_once_with(
+        local_dependencies_path=local_dependencies_path,
+        include_local_workdir=True,
+        pre_execution_commands=None,
+        pre_execution_script_local_path=None,
+        s3_base_uri=f"{S3_URI}/{job.job_name}",
+        s3_kms_key=None,
+        sagemaker_session=session(),
+        custom_file_filter=None,
+    )
+
+    session().sagemaker_client.create_training_job.assert_called_once_with(
+        TrainingJobName=job.job_name,
+        RoleArn=ROLE_ARN,
+        StoppingCondition={"MaxRuntimeInSeconds": 86400},
+        RetryStrategy={"MaximumRetryAttempts": 1},
+        InputDataConfig=[
+            dict(
+                ChannelName=RUNTIME_SCRIPTS_CHANNEL_NAME,
+                DataSource={
+                    "S3DataSource": {
+                        "S3Uri": mock_script_upload.return_value,
+                        "S3DataType": "S3Prefix",
+                        "S3DataDistributionType": "FullyReplicated",
+                    }
+                },
+            ),
+            dict(
+                ChannelName=REMOTE_FUNCTION_WORKSPACE,
+                DataSource={
+                    "S3DataSource": {
+                        "S3Uri": mock_dependency_upload.return_value,
+                        "S3DataType": "S3Prefix",
+                        "S3DataDistributionType": "FullyReplicated",
+                    }
+                },
+            ),
+        ],
+        OutputDataConfig={"S3OutputPath": f"{S3_URI}/{job.job_name}"},
+        AlgorithmSpecification=dict(
+            TrainingImage=IMAGE,
+            TrainingInputMode="File",
+            ContainerEntrypoint=[
+                "/bin/bash",
+                "/opt/ml/input/data/sagemaker_remote_function_bootstrap/job_driver.sh",
+            ],
+            ContainerArguments=[
+                "--s3_base_uri",
+                f"{S3_URI}/{job.job_name}",
+                "--region",
+                TEST_REGION,
+                "--client_python_version",
+                mock_python_version,
+                "--client_sagemaker_pysdk_version",
+                mock_sagemaker_pysdk_version,
+                "--dependency_settings",
+                '{"dependency_file": null}',
+                "--run_in_context",
+                '{"experiment_name": "my-exp-name", "run_name": "my-run-name"}',
+            ],
+        ),
+        ResourceConfig=dict(
+            VolumeSizeInGB=30,
+            InstanceCount=2,
+            InstanceType="ml.g5.2xlarge",
+            KeepAlivePeriodInSeconds=0,
+        ),
+        EnableNetworkIsolation=False,
+        EnableInterContainerTrafficEncryption=True,
+        EnableManagedSpotTraining=False,
+        Environment={"AWS_DEFAULT_REGION": "us-west-2", "REMOTE_FUNCTION_SECRET_KEY": HMAC_KEY},
+    )
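The two new tests cover the torchrun path for single-node and multi-node jobs; together with the use_torchrun=False assertions added to the existing tests, they pin down the keyword now passed to _prepare_and_upload_runtime_scripts. For a caller-side view, assuming the flag is surfaced on the @remote decorator the same way the other _JobSettings keywords are (the image, bucket, and role below are placeholders):

from sagemaker.remote_function import remote


@remote(
    image_uri="123456789012.dkr.ecr.us-west-2.amazonaws.com/train:latest",  # placeholder
    s3_root_uri="s3://my-bucket/remote-function",  # placeholder
    role="arn:aws:iam::123456789012:role/SageMakerRole",  # placeholder
    instance_type="ml.g5.12xlarge",
    instance_count=1,
    use_torchrun=True,  # run the function under torchrun inside the training container
)
def train(epochs: int):
    ...


train(10)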
