@@ -389,6 +389,7 @@ def test_start(
389
389
s3_base_uri = f"{ S3_URI } /{ job .job_name } " ,
390
390
s3_kms_key = None ,
391
391
sagemaker_session = session (),
392
+ use_torchrun = False ,
392
393
)
393
394
394
395
mock_dependency_upload .assert_called_once_with (
@@ -670,6 +671,7 @@ def test_start_with_complete_job_settings(
670
671
s3_base_uri = f"{ S3_URI } /{ job .job_name } " ,
671
672
s3_kms_key = job_settings .s3_kms_key ,
672
673
sagemaker_session = session (),
674
+ use_torchrun = False ,
673
675
)
674
676
675
677
mock_user_workspace_upload .assert_called_once_with (
@@ -840,6 +842,7 @@ def test_get_train_args_under_pipeline_context(
840
842
s3_base_uri = s3_base_uri ,
841
843
s3_kms_key = job_settings .s3_kms_key ,
842
844
sagemaker_session = session (),
845
+ use_torchrun = False ,
843
846
)
844
847
845
848
mock_user_workspace_upload .assert_called_once_with (
@@ -1014,6 +1017,7 @@ def test_start_with_spark(
1014
1017
s3_base_uri = f"{ S3_URI } /{ job .job_name } " ,
1015
1018
s3_kms_key = None ,
1016
1019
sagemaker_session = session (),
1020
+ use_torchrun = False ,
1017
1021
)
1018
1022
1019
1023
session ().sagemaker_client .create_training_job .assert_called_once_with (
@@ -1601,3 +1605,256 @@ def test_extend_spark_config_to_request(
1601
1605
}
1602
1606
],
1603
1607
)
1608
+
1609
+
1610
+ @patch ("sagemaker.experiments._run_context._RunContext.get_current_run" , new = mock_get_current_run )
1611
+ @patch ("secrets.token_hex" , return_value = HMAC_KEY )
1612
+ @patch ("sagemaker.remote_function.job._prepare_and_upload_workspace" , return_value = "some_s3_uri" )
1613
+ @patch (
1614
+ "sagemaker.remote_function.job._prepare_and_upload_runtime_scripts" , return_value = "some_s3_uri"
1615
+ )
1616
+ @patch ("sagemaker.remote_function.job.RuntimeEnvironmentManager" )
1617
+ @patch ("sagemaker.remote_function.job.StoredFunction" )
1618
+ @patch ("sagemaker.remote_function.job.Session" , return_value = mock_session ())
1619
+ def test_start_with_torchrun_single_node (
1620
+ session ,
1621
+ mock_stored_function ,
1622
+ mock_runtime_manager ,
1623
+ mock_script_upload ,
1624
+ mock_dependency_upload ,
1625
+ secret_token ,
1626
+ ):
1627
+
1628
+ job_settings = _JobSettings (
1629
+ image_uri = IMAGE ,
1630
+ s3_root_uri = S3_URI ,
1631
+ role = ROLE_ARN ,
1632
+ include_local_workdir = True ,
1633
+ instance_type = "ml.g5.12xlarge" ,
1634
+ encrypt_inter_container_traffic = True ,
1635
+ use_torchrun = True ,
1636
+ )
1637
+
1638
+ job = _Job .start (job_settings , job_function , func_args = (1 , 2 ), func_kwargs = {"c" : 3 , "d" : 4 })
1639
+
1640
+ assert job .job_name .startswith ("job-function" )
1641
+
1642
+ mock_stored_function .assert_called_once_with (
1643
+ sagemaker_session = session (),
1644
+ s3_base_uri = f"{ S3_URI } /{ job .job_name } " ,
1645
+ hmac_key = HMAC_KEY ,
1646
+ s3_kms_key = None ,
1647
+ )
1648
+
1649
+ mock_stored_function ().save .assert_called_once_with (job_function , * (1 , 2 ), ** {"c" : 3 , "d" : 4 })
1650
+
1651
+ local_dependencies_path = mock_runtime_manager ().snapshot ()
1652
+ mock_python_version = mock_runtime_manager ()._current_python_version ()
1653
+ mock_sagemaker_pysdk_version = mock_runtime_manager ()._current_sagemaker_pysdk_version ()
1654
+
1655
+ mock_script_upload .assert_called_once_with (
1656
+ spark_config = None ,
1657
+ s3_base_uri = f"{ S3_URI } /{ job .job_name } " ,
1658
+ s3_kms_key = None ,
1659
+ sagemaker_session = session (),
1660
+ use_torchrun = True ,
1661
+ )
1662
+
1663
+ mock_dependency_upload .assert_called_once_with (
1664
+ local_dependencies_path = local_dependencies_path ,
1665
+ include_local_workdir = True ,
1666
+ pre_execution_commands = None ,
1667
+ pre_execution_script_local_path = None ,
1668
+ s3_base_uri = f"{ S3_URI } /{ job .job_name } " ,
1669
+ s3_kms_key = None ,
1670
+ sagemaker_session = session (),
1671
+ custom_file_filter = None ,
1672
+ )
1673
+
1674
+ session ().sagemaker_client .create_training_job .assert_called_once_with (
1675
+ TrainingJobName = job .job_name ,
1676
+ RoleArn = ROLE_ARN ,
1677
+ StoppingCondition = {"MaxRuntimeInSeconds" : 86400 },
1678
+ RetryStrategy = {"MaximumRetryAttempts" : 1 },
1679
+ InputDataConfig = [
1680
+ dict (
1681
+ ChannelName = RUNTIME_SCRIPTS_CHANNEL_NAME ,
1682
+ DataSource = {
1683
+ "S3DataSource" : {
1684
+ "S3Uri" : mock_script_upload .return_value ,
1685
+ "S3DataType" : "S3Prefix" ,
1686
+ }
1687
+ },
1688
+ ),
1689
+ dict (
1690
+ ChannelName = REMOTE_FUNCTION_WORKSPACE ,
1691
+ DataSource = {
1692
+ "S3DataSource" : {
1693
+ "S3Uri" : mock_dependency_upload .return_value ,
1694
+ "S3DataType" : "S3Prefix" ,
1695
+ }
1696
+ },
1697
+ ),
1698
+ ],
1699
+ OutputDataConfig = {"S3OutputPath" : f"{ S3_URI } /{ job .job_name } " },
1700
+ AlgorithmSpecification = dict (
1701
+ TrainingImage = IMAGE ,
1702
+ TrainingInputMode = "File" ,
1703
+ ContainerEntrypoint = [
1704
+ "/bin/bash" ,
1705
+ "/opt/ml/input/data/sagemaker_remote_function_bootstrap/job_driver.sh" ,
1706
+ ],
1707
+ ContainerArguments = [
1708
+ "--s3_base_uri" ,
1709
+ f"{ S3_URI } /{ job .job_name } " ,
1710
+ "--region" ,
1711
+ TEST_REGION ,
1712
+ "--client_python_version" ,
1713
+ mock_python_version ,
1714
+ "--client_sagemaker_pysdk_version" ,
1715
+ mock_sagemaker_pysdk_version ,
1716
+ "--dependency_settings" ,
1717
+ '{"dependency_file": null}' ,
1718
+ "--run_in_context" ,
1719
+ '{"experiment_name": "my-exp-name", "run_name": "my-run-name"}' ,
1720
+ ],
1721
+ ),
1722
+ ResourceConfig = dict (
1723
+ VolumeSizeInGB = 30 ,
1724
+ InstanceCount = 1 ,
1725
+ InstanceType = "ml.g5.12xlarge" ,
1726
+ KeepAlivePeriodInSeconds = 0 ,
1727
+ ),
1728
+ EnableNetworkIsolation = False ,
1729
+ EnableInterContainerTrafficEncryption = True ,
1730
+ EnableManagedSpotTraining = False ,
1731
+ Environment = {"AWS_DEFAULT_REGION" : "us-west-2" , "REMOTE_FUNCTION_SECRET_KEY" : HMAC_KEY },
1732
+ )
1733
+
1734
+
1735
+ @patch ("sagemaker.experiments._run_context._RunContext.get_current_run" , new = mock_get_current_run )
1736
+ @patch ("secrets.token_hex" , return_value = HMAC_KEY )
1737
+ @patch ("sagemaker.remote_function.job._prepare_and_upload_workspace" , return_value = "some_s3_uri" )
1738
+ @patch (
1739
+ "sagemaker.remote_function.job._prepare_and_upload_runtime_scripts" , return_value = "some_s3_uri"
1740
+ )
1741
+ @patch ("sagemaker.remote_function.job.RuntimeEnvironmentManager" )
1742
+ @patch ("sagemaker.remote_function.job.StoredFunction" )
1743
+ @patch ("sagemaker.remote_function.job.Session" , return_value = mock_session ())
1744
+ def test_start_with_torchrun_multi_node (
1745
+ session ,
1746
+ mock_stored_function ,
1747
+ mock_runtime_manager ,
1748
+ mock_script_upload ,
1749
+ mock_dependency_upload ,
1750
+ secret_token ,
1751
+ ):
1752
+
1753
+ job_settings = _JobSettings (
1754
+ image_uri = IMAGE ,
1755
+ s3_root_uri = S3_URI ,
1756
+ role = ROLE_ARN ,
1757
+ include_local_workdir = True ,
1758
+ instance_count = 2 ,
1759
+ instance_type = "ml.g5.2xlarge" ,
1760
+ encrypt_inter_container_traffic = True ,
1761
+ use_torchrun = True ,
1762
+ )
1763
+
1764
+ job = _Job .start (job_settings , job_function , func_args = (1 , 2 ), func_kwargs = {"c" : 3 , "d" : 4 })
1765
+
1766
+ assert job .job_name .startswith ("job-function" )
1767
+
1768
+ mock_stored_function .assert_called_once_with (
1769
+ sagemaker_session = session (),
1770
+ s3_base_uri = f"{ S3_URI } /{ job .job_name } " ,
1771
+ hmac_key = HMAC_KEY ,
1772
+ s3_kms_key = None ,
1773
+ )
1774
+
1775
+ mock_stored_function ().save .assert_called_once_with (job_function , * (1 , 2 ), ** {"c" : 3 , "d" : 4 })
1776
+
1777
+ local_dependencies_path = mock_runtime_manager ().snapshot ()
1778
+ mock_python_version = mock_runtime_manager ()._current_python_version ()
1779
+ mock_sagemaker_pysdk_version = mock_runtime_manager ()._current_sagemaker_pysdk_version ()
1780
+
1781
+ mock_script_upload .assert_called_once_with (
1782
+ spark_config = None ,
1783
+ s3_base_uri = f"{ S3_URI } /{ job .job_name } " ,
1784
+ s3_kms_key = None ,
1785
+ sagemaker_session = session (),
1786
+ use_torchrun = True ,
1787
+ )
1788
+
1789
+ mock_dependency_upload .assert_called_once_with (
1790
+ local_dependencies_path = local_dependencies_path ,
1791
+ include_local_workdir = True ,
1792
+ pre_execution_commands = None ,
1793
+ pre_execution_script_local_path = None ,
1794
+ s3_base_uri = f"{ S3_URI } /{ job .job_name } " ,
1795
+ s3_kms_key = None ,
1796
+ sagemaker_session = session (),
1797
+ custom_file_filter = None ,
1798
+ )
1799
+
1800
+ session ().sagemaker_client .create_training_job .assert_called_once_with (
1801
+ TrainingJobName = job .job_name ,
1802
+ RoleArn = ROLE_ARN ,
1803
+ StoppingCondition = {"MaxRuntimeInSeconds" : 86400 },
1804
+ RetryStrategy = {"MaximumRetryAttempts" : 1 },
1805
+ InputDataConfig = [
1806
+ dict (
1807
+ ChannelName = RUNTIME_SCRIPTS_CHANNEL_NAME ,
1808
+ DataSource = {
1809
+ "S3DataSource" : {
1810
+ "S3Uri" : mock_script_upload .return_value ,
1811
+ "S3DataType" : "S3Prefix" ,
1812
+ "S3DataDistributionType" : "FullyReplicated" ,
1813
+ }
1814
+ },
1815
+ ),
1816
+ dict (
1817
+ ChannelName = REMOTE_FUNCTION_WORKSPACE ,
1818
+ DataSource = {
1819
+ "S3DataSource" : {
1820
+ "S3Uri" : mock_dependency_upload .return_value ,
1821
+ "S3DataType" : "S3Prefix" ,
1822
+ "S3DataDistributionType" : "FullyReplicated" ,
1823
+ }
1824
+ },
1825
+ ),
1826
+ ],
1827
+ OutputDataConfig = {"S3OutputPath" : f"{ S3_URI } /{ job .job_name } " },
1828
+ AlgorithmSpecification = dict (
1829
+ TrainingImage = IMAGE ,
1830
+ TrainingInputMode = "File" ,
1831
+ ContainerEntrypoint = [
1832
+ "/bin/bash" ,
1833
+ "/opt/ml/input/data/sagemaker_remote_function_bootstrap/job_driver.sh" ,
1834
+ ],
1835
+ ContainerArguments = [
1836
+ "--s3_base_uri" ,
1837
+ f"{ S3_URI } /{ job .job_name } " ,
1838
+ "--region" ,
1839
+ TEST_REGION ,
1840
+ "--client_python_version" ,
1841
+ mock_python_version ,
1842
+ "--client_sagemaker_pysdk_version" ,
1843
+ mock_sagemaker_pysdk_version ,
1844
+ "--dependency_settings" ,
1845
+ '{"dependency_file": null}' ,
1846
+ "--run_in_context" ,
1847
+ '{"experiment_name": "my-exp-name", "run_name": "my-run-name"}' ,
1848
+ ],
1849
+ ),
1850
+ ResourceConfig = dict (
1851
+ VolumeSizeInGB = 30 ,
1852
+ InstanceCount = 2 ,
1853
+ InstanceType = "ml.g5.2xlarge" ,
1854
+ KeepAlivePeriodInSeconds = 0 ,
1855
+ ),
1856
+ EnableNetworkIsolation = False ,
1857
+ EnableInterContainerTrafficEncryption = True ,
1858
+ EnableManagedSpotTraining = False ,
1859
+ Environment = {"AWS_DEFAULT_REGION" : "us-west-2" , "REMOTE_FUNCTION_SECRET_KEY" : HMAC_KEY },
1860
+ )
0 commit comments