@@ -389,6 +389,7 @@ def test_start(
389389 s3_base_uri = f"{ S3_URI } /{ job .job_name } " ,
390390 s3_kms_key = None ,
391391 sagemaker_session = session (),
392+ use_torchrun = False ,
392393 )
393394
394395 mock_dependency_upload .assert_called_once_with (
@@ -670,6 +671,7 @@ def test_start_with_complete_job_settings(
670671 s3_base_uri = f"{ S3_URI } /{ job .job_name } " ,
671672 s3_kms_key = job_settings .s3_kms_key ,
672673 sagemaker_session = session (),
674+ use_torchrun = False ,
673675 )
674676
675677 mock_user_workspace_upload .assert_called_once_with (
@@ -840,6 +842,7 @@ def test_get_train_args_under_pipeline_context(
840842 s3_base_uri = s3_base_uri ,
841843 s3_kms_key = job_settings .s3_kms_key ,
842844 sagemaker_session = session (),
845+ use_torchrun = False ,
843846 )
844847
845848 mock_user_workspace_upload .assert_called_once_with (
@@ -1014,6 +1017,7 @@ def test_start_with_spark(
10141017 s3_base_uri = f"{ S3_URI } /{ job .job_name } " ,
10151018 s3_kms_key = None ,
10161019 sagemaker_session = session (),
1020+ use_torchrun = False ,
10171021 )
10181022
10191023 session ().sagemaker_client .create_training_job .assert_called_once_with (
@@ -1601,3 +1605,256 @@ def test_extend_spark_config_to_request(
16011605 }
16021606 ],
16031607 )
1608+
1609+
@patch("sagemaker.experiments._run_context._RunContext.get_current_run", new=mock_get_current_run)
@patch("secrets.token_hex", return_value=HMAC_KEY)
@patch("sagemaker.remote_function.job._prepare_and_upload_workspace", return_value="some_s3_uri")
@patch(
    "sagemaker.remote_function.job._prepare_and_upload_runtime_scripts", return_value="some_s3_uri"
)
@patch("sagemaker.remote_function.job.RuntimeEnvironmentManager")
@patch("sagemaker.remote_function.job.StoredFunction")
@patch("sagemaker.remote_function.job.Session", return_value=mock_session())
def test_start_with_torchrun_single_node(
    session,
    mock_stored_function,
    mock_runtime_manager,
    mock_script_upload,
    mock_dependency_upload,
    secret_token,
):
    """Check ``_Job.start`` with ``use_torchrun=True`` and the default instance count.

    Asserts that the flag is forwarded to the runtime-scripts upload and that
    the ``create_training_job`` request describes one ``ml.g5.12xlarge``
    instance, with input channels that carry only ``S3Uri`` and ``S3DataType``.
    """
    job_settings = _JobSettings(
        image_uri=IMAGE,
        s3_root_uri=S3_URI,
        role=ROLE_ARN,
        include_local_workdir=True,
        instance_type="ml.g5.12xlarge",
        encrypt_inter_container_traffic=True,
        use_torchrun=True,
    )

    job = _Job.start(job_settings, job_function, func_args=(1, 2), func_kwargs={"c": 3, "d": 4})

    assert job.job_name.startswith("job-function")

    # Values shared by several of the expectations below.
    sm_session = session()
    job_s3_uri = f"{S3_URI}/{job.job_name}"

    # The function and its arguments are serialized under the job's S3 prefix.
    mock_stored_function.assert_called_once_with(
        sagemaker_session=sm_session,
        s3_base_uri=job_s3_uri,
        hmac_key=HMAC_KEY,
        s3_kms_key=None,
    )
    mock_stored_function().save.assert_called_once_with(job_function, 1, 2, c=3, d=4)

    # The patched manager hands back the same child mocks on every call, so
    # these are the exact objects the code under test received.
    runtime_manager = mock_runtime_manager()
    snapshot_path = runtime_manager.snapshot()
    client_python_version = runtime_manager._current_python_version()
    client_sdk_version = runtime_manager._current_sagemaker_pysdk_version()

    mock_script_upload.assert_called_once_with(
        spark_config=None,
        s3_base_uri=job_s3_uri,
        s3_kms_key=None,
        sagemaker_session=sm_session,
        use_torchrun=True,
    )

    mock_dependency_upload.assert_called_once_with(
        local_dependencies_path=snapshot_path,
        include_local_workdir=True,
        pre_execution_commands=None,
        pre_execution_script_local_path=None,
        s3_base_uri=job_s3_uri,
        s3_kms_key=None,
        sagemaker_session=sm_session,
        custom_file_filter=None,
    )

    # One channel per uploaded artifact: runtime scripts first, then workspace.
    expected_channels = [
        {
            "ChannelName": channel_name,
            "DataSource": {
                "S3DataSource": {
                    "S3Uri": uploaded_uri,
                    "S3DataType": "S3Prefix",
                }
            },
        }
        for channel_name, uploaded_uri in (
            (RUNTIME_SCRIPTS_CHANNEL_NAME, mock_script_upload.return_value),
            (REMOTE_FUNCTION_WORKSPACE, mock_dependency_upload.return_value),
        )
    ]

    expected_algorithm_spec = {
        "TrainingImage": IMAGE,
        "TrainingInputMode": "File",
        "ContainerEntrypoint": [
            "/bin/bash",
            "/opt/ml/input/data/sagemaker_remote_function_bootstrap/job_driver.sh",
        ],
        "ContainerArguments": [
            "--s3_base_uri",
            job_s3_uri,
            "--region",
            TEST_REGION,
            "--client_python_version",
            client_python_version,
            "--client_sagemaker_pysdk_version",
            client_sdk_version,
            "--dependency_settings",
            '{"dependency_file": null}',
            "--run_in_context",
            '{"experiment_name": "my-exp-name", "run_name": "my-run-name"}',
        ],
    }

    expected_resource_config = {
        "VolumeSizeInGB": 30,
        "InstanceCount": 1,
        "InstanceType": "ml.g5.12xlarge",
        "KeepAlivePeriodInSeconds": 0,
    }

    sm_session.sagemaker_client.create_training_job.assert_called_once_with(
        TrainingJobName=job.job_name,
        RoleArn=ROLE_ARN,
        StoppingCondition={"MaxRuntimeInSeconds": 86400},
        RetryStrategy={"MaximumRetryAttempts": 1},
        InputDataConfig=expected_channels,
        OutputDataConfig={"S3OutputPath": job_s3_uri},
        AlgorithmSpecification=expected_algorithm_spec,
        ResourceConfig=expected_resource_config,
        EnableNetworkIsolation=False,
        EnableInterContainerTrafficEncryption=True,
        EnableManagedSpotTraining=False,
        Environment={
            "AWS_DEFAULT_REGION": "us-west-2",
            "REMOTE_FUNCTION_SECRET_KEY": HMAC_KEY,
        },
    )
1733+
1734+
@patch("sagemaker.experiments._run_context._RunContext.get_current_run", new=mock_get_current_run)
@patch("secrets.token_hex", return_value=HMAC_KEY)
@patch("sagemaker.remote_function.job._prepare_and_upload_workspace", return_value="some_s3_uri")
@patch(
    "sagemaker.remote_function.job._prepare_and_upload_runtime_scripts", return_value="some_s3_uri"
)
@patch("sagemaker.remote_function.job.RuntimeEnvironmentManager")
@patch("sagemaker.remote_function.job.StoredFunction")
@patch("sagemaker.remote_function.job.Session", return_value=mock_session())
def test_start_with_torchrun_multi_node(
    session,
    mock_stored_function,
    mock_runtime_manager,
    mock_script_upload,
    mock_dependency_upload,
    secret_token,
):
    """Check ``_Job.start`` with ``use_torchrun=True`` and ``instance_count=2``.

    Asserts that the flag is forwarded to the runtime-scripts upload and that
    the ``create_training_job`` request describes two ``ml.g5.2xlarge``
    instances, with both input channels marked
    ``S3DataDistributionType="FullyReplicated"``.
    """
    job_settings = _JobSettings(
        image_uri=IMAGE,
        s3_root_uri=S3_URI,
        role=ROLE_ARN,
        include_local_workdir=True,
        instance_count=2,
        instance_type="ml.g5.2xlarge",
        encrypt_inter_container_traffic=True,
        use_torchrun=True,
    )

    job = _Job.start(job_settings, job_function, func_args=(1, 2), func_kwargs={"c": 3, "d": 4})

    assert job.job_name.startswith("job-function")

    # Values shared by several of the expectations below.
    sm_session = session()
    job_s3_uri = f"{S3_URI}/{job.job_name}"

    # The function and its arguments are serialized under the job's S3 prefix.
    mock_stored_function.assert_called_once_with(
        sagemaker_session=sm_session,
        s3_base_uri=job_s3_uri,
        hmac_key=HMAC_KEY,
        s3_kms_key=None,
    )
    mock_stored_function().save.assert_called_once_with(job_function, 1, 2, c=3, d=4)

    # The patched manager hands back the same child mocks on every call, so
    # these are the exact objects the code under test received.
    runtime_manager = mock_runtime_manager()
    snapshot_path = runtime_manager.snapshot()
    client_python_version = runtime_manager._current_python_version()
    client_sdk_version = runtime_manager._current_sagemaker_pysdk_version()

    mock_script_upload.assert_called_once_with(
        spark_config=None,
        s3_base_uri=job_s3_uri,
        s3_kms_key=None,
        sagemaker_session=sm_session,
        use_torchrun=True,
    )

    mock_dependency_upload.assert_called_once_with(
        local_dependencies_path=snapshot_path,
        include_local_workdir=True,
        pre_execution_commands=None,
        pre_execution_script_local_path=None,
        s3_base_uri=job_s3_uri,
        s3_kms_key=None,
        sagemaker_session=sm_session,
        custom_file_filter=None,
    )

    # One channel per uploaded artifact: runtime scripts first, then workspace.
    # Both are expected to carry an explicit distribution type in this test.
    expected_channels = [
        {
            "ChannelName": channel_name,
            "DataSource": {
                "S3DataSource": {
                    "S3Uri": uploaded_uri,
                    "S3DataType": "S3Prefix",
                    "S3DataDistributionType": "FullyReplicated",
                }
            },
        }
        for channel_name, uploaded_uri in (
            (RUNTIME_SCRIPTS_CHANNEL_NAME, mock_script_upload.return_value),
            (REMOTE_FUNCTION_WORKSPACE, mock_dependency_upload.return_value),
        )
    ]

    expected_algorithm_spec = {
        "TrainingImage": IMAGE,
        "TrainingInputMode": "File",
        "ContainerEntrypoint": [
            "/bin/bash",
            "/opt/ml/input/data/sagemaker_remote_function_bootstrap/job_driver.sh",
        ],
        "ContainerArguments": [
            "--s3_base_uri",
            job_s3_uri,
            "--region",
            TEST_REGION,
            "--client_python_version",
            client_python_version,
            "--client_sagemaker_pysdk_version",
            client_sdk_version,
            "--dependency_settings",
            '{"dependency_file": null}',
            "--run_in_context",
            '{"experiment_name": "my-exp-name", "run_name": "my-run-name"}',
        ],
    }

    expected_resource_config = {
        "VolumeSizeInGB": 30,
        "InstanceCount": 2,
        "InstanceType": "ml.g5.2xlarge",
        "KeepAlivePeriodInSeconds": 0,
    }

    sm_session.sagemaker_client.create_training_job.assert_called_once_with(
        TrainingJobName=job.job_name,
        RoleArn=ROLE_ARN,
        StoppingCondition={"MaxRuntimeInSeconds": 86400},
        RetryStrategy={"MaximumRetryAttempts": 1},
        InputDataConfig=expected_channels,
        OutputDataConfig={"S3OutputPath": job_s3_uri},
        AlgorithmSpecification=expected_algorithm_spec,
        ResourceConfig=expected_resource_config,
        EnableNetworkIsolation=False,
        EnableInterContainerTrafficEncryption=True,
        EnableManagedSpotTraining=False,
        Environment={
            "AWS_DEFAULT_REGION": "us-west-2",
            "REMOTE_FUNCTION_SECRET_KEY": HMAC_KEY,
        },
    )
0 commit comments