1010# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111# ANY KIND, either express or implied. See the License for the specific
1212# language governing permissions and limitations under the License.
13- """Utility methods used by framework classes"""
13+ """Utility methods used by framework classes. """
1414from __future__ import absolute_import
1515
1616import json
1717import logging
1818import os
1919import re
20- import time
2120import shutil
2221import tempfile
22+ import time
2323from collections import namedtuple
24- from typing import List , Optional , Union , Dict
24+ from typing import Dict , List , Optional , Union
25+
2526from packaging import version
2627
2728import sagemaker .image_uris
29+ import sagemaker .utils
30+ from sagemaker .deprecations import deprecation_warn_base , renamed_kwargs , renamed_warning
2831from sagemaker .instance_group import InstanceGroup
2932from sagemaker .s3_utils import s3_path_join
3033from sagemaker .session_settings import SessionSettings
31- import sagemaker .utils
3234from sagemaker .workflow import is_pipeline_variable
33-
34- from sagemaker .deprecations import renamed_warning , renamed_kwargs
3535from sagemaker .workflow .entities import PipelineVariable
36- from sagemaker .deprecations import deprecation_warn_base
3736
3837logger = logging .getLogger (__name__ )
3938
4039_TAR_SOURCE_FILENAME = "source.tar.gz"
4140
4241UploadedCode = namedtuple ("UploadedCode" , ["s3_prefix" , "script_name" ])
4342"""sagemaker.fw_utils.UploadedCode: An object containing the S3 prefix and script name.
43+
4444This is for the source code used for the entry point with an ``Estimator``. It can be
4545instantiated with positional or keyword arguments.
4646"""
@@ -211,7 +211,7 @@ def validate_source_code_input_against_pipeline_variables(
211211 git_config : Optional [Dict [str , str ]] = None ,
212212 enable_network_isolation : Union [bool , PipelineVariable ] = False ,
213213):
214- """Validate source code input against pipeline variables
214+ """Validate source code input against pipeline variables.
215215
216216 Args:
217217 entry_point (str or PipelineVariable): The path to the local Python source file that
@@ -481,7 +481,7 @@ def tar_and_upload_dir(
481481
482482
483483def _list_files_to_compress (script , directory ):
484- """Placeholder docstring"""
484+ """Placeholder docstring. """
485485 if directory is None :
486486 return [script ]
487487
@@ -585,7 +585,6 @@ def model_code_key_prefix(code_location_key_prefix, model_name, image):
585585 The location returned is a potential concatenation of 2 parts
586586 1. code_location_key_prefix if it exists
587587 2. model_name or a name derived from the image
588-
589588 Args:
590589 code_location_key_prefix (str): the s3 key prefix from code_location
591590 model_name (str): the name of the model
@@ -620,8 +619,6 @@ def warn_if_parameter_server_with_multi_gpu(training_instance_type, distribution
620619 "enabled": True
621620 }
622621 }
623-
624-
625622 """
626623 if training_instance_type == "local" or distribution is None :
627624 return
@@ -646,7 +643,7 @@ def warn_if_parameter_server_with_multi_gpu(training_instance_type, distribution
646643def profiler_config_deprecation_warning (
647644 profiler_config , image_uri , framework_name , framework_version
648645):
649- """Put out a deprecation message for if framework profiling is specified TF >= 2.12 and PT >= 2.0"""
646+ """Deprecation message if framework profiling is specified TF >= 2.12 and PT >= 2.0. """
650647 if profiler_config is None or profiler_config .framework_profile_params is None :
651648 return
652649
@@ -692,6 +689,7 @@ def validate_smdistributed(
692689 framework_name (str): A string representing the name of framework selected.
693690 framework_version (str): A string representing the framework version selected.
694691 py_version (str): A string representing the python version selected.
692+ Ex: `py38, py39, py310, py311`
695693 distribution (dict): A dictionary with information to enable distributed training.
696694 (Defaults to None if distributed training is not enabled.) For example:
697695
@@ -763,7 +761,8 @@ def _validate_smdataparallel_args(
763761 instance_type (str): A string representing the type of training instance selected. Ex: `ml.p3.16xlarge`
764762 framework_name (str): A string representing the name of framework selected. Ex: `tensorflow`
765763 framework_version (str): A string representing the framework version selected. Ex: `2.3.1`
766- py_version (str): A string representing the python version selected. Ex: `py3`
764+ py_version (str): A string representing the python version selected.
765+ Ex: `py38, py39, py310, py311`
767766 distribution (dict): A dictionary with information to enable distributed training.
768767 (Defaults to None if distributed training is not enabled.) Ex:
769768
@@ -847,6 +846,7 @@ def validate_distribution(
847846 framework_name (str): A string representing the name of framework selected.
848847 framework_version (str): A string representing the framework version selected.
849848 py_version (str): A string representing the python version selected.
849+ Ex: `py38, py39, py310, py311`
850850 image_uri (str): A string representing a Docker image URI.
851851 kwargs(dict): Additional kwargs passed to this function
852852
@@ -953,7 +953,7 @@ def validate_distribution(
953953
954954
955955def validate_distribution_for_instance_type (instance_type , distribution ):
956- """Check if the provided distribution strategy is supported for the instance_type
956+ """Check if the provided distribution strategy is supported for the instance_type.
957957
958958 Args:
959959 instance_type (str): A string representing the type of training instance selected.
@@ -1010,6 +1010,7 @@ def validate_torch_distributed_distribution(
10101010 }
10111011 framework_version (str): A string representing the framework version selected.
10121012 py_version (str): A string representing the python version selected.
1013+ Ex: `py38, py39, py310, py311`
10131014 image_uri (str): A string representing a Docker image URI.
10141015 entry_point (str or PipelineVariable): The absolute or relative path to the local Python
10151016 source file that should be executed as the entry point to
@@ -1072,7 +1073,7 @@ def validate_torch_distributed_distribution(
10721073
10731074
10741075def _is_gpu_instance (instance_type ):
1075- """Returns bool indicating whether instance_type supports GPU
1076+ """Returns bool indicating whether instance_type supports GPU.
10761077
10771078 Args:
10781079 instance_type (str): Name of the instance_type to check against.
@@ -1091,7 +1092,7 @@ def _is_gpu_instance(instance_type):
10911092
10921093
10931094def _is_trainium_instance (instance_type ):
1094- """Returns bool indicating whether instance_type is a Trainium instance
1095+ """Returns bool indicating whether instance_type is a Trainium instance.
10951096
10961097 Args:
10971098 instance_type (str): Name of the instance_type to check against.
@@ -1107,7 +1108,7 @@ def _is_trainium_instance(instance_type):
11071108
11081109
11091110def python_deprecation_warning (framework , latest_supported_version ):
1110- """Placeholder docstring"""
1111+ """Placeholder docstring. """
11111112 return PYTHON_2_DEPRECATION_WARNING .format (
11121113 framework = framework , latest_supported_version = latest_supported_version
11131114 )
@@ -1121,7 +1122,6 @@ def _region_supports_debugger(region_name):
11211122
11221123 Returns:
11231124 bool: Whether or not the region supports Amazon SageMaker Debugger.
1124-
11251125 """
11261126 return region_name .lower () not in DEBUGGER_UNSUPPORTED_REGIONS
11271127
@@ -1134,7 +1134,6 @@ def _region_supports_profiler(region_name):
11341134
11351135 Returns:
11361136 bool: Whether or not the region supports Amazon SageMaker Debugger profiling feature.
1137-
11381137 """
11391138 return region_name .lower () not in PROFILER_UNSUPPORTED_REGIONS
11401139
@@ -1162,7 +1161,8 @@ def validate_version_or_image_args(framework_version, py_version, image_uri):
11621161
11631162 Args:
11641163 framework_version (str): The version of the framework.
1165- py_version (str): The version of Python.
1164+ py_version (str): A string representing the python version selected.
1165+ Ex: `py38, py39, py310, py311`
11661166 image_uri (str): The URI of the image.
11671167
11681168 Raises:
@@ -1194,9 +1194,8 @@ def create_image_uri(
11941194 instance_type (str): SageMaker instance type. Used to determine device
11951195 type (cpu/gpu/family-specific optimized).
11961196 framework_version (str): The version of the framework.
1197- py_version (str): Optional. Python version. If specified, should be one
1198- of 'py2' or 'py3'. If not specified, image uri will not include a
1199- python component.
1197+ py_version (str): Optional. Python version Ex: `py38, py39, py310, py311`.
1198+ If not specified, image uri will not include a python component.
12001199 account (str): AWS account that contains the image. (default:
12011200 '520713654638')
12021201 accelerator_type (str): SageMaker Elastic Inference accelerator type.
0 commit comments