10
10
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
11
# ANY KIND, either express or implied. See the License for the specific
12
12
# language governing permissions and limitations under the License.
13
- """Utility methods used by framework classes"""
13
+ """Utility methods used by framework classes. """
14
14
from __future__ import absolute_import
15
15
16
16
import json
40
40
41
41
UploadedCode = namedtuple ("UploadedCode" , ["s3_prefix" , "script_name" ])
42
42
"""sagemaker.fw_utils.UploadedCode: An object containing the S3 prefix and script name.
43
+
43
44
This is for the source code used for the entry point with an ``Estimator``. It can be
44
45
instantiated with positional or keyword arguments.
45
46
"""
@@ -210,7 +211,7 @@ def validate_source_code_input_against_pipeline_variables(
210
211
git_config : Optional [Dict [str , str ]] = None ,
211
212
enable_network_isolation : Union [bool , PipelineVariable ] = False ,
212
213
):
213
- """Validate source code input against pipeline variables
214
+ """Validate source code input against pipeline variables.
214
215
215
216
Args:
216
217
entry_point (str or PipelineVariable): The path to the local Python source file that
@@ -480,7 +481,7 @@ def tar_and_upload_dir(
480
481
481
482
482
483
def _list_files_to_compress (script , directory ):
483
- """Placeholder docstring"""
484
+ """Placeholder docstring. """
484
485
if directory is None :
485
486
return [script ]
486
487
@@ -619,8 +620,6 @@ def warn_if_parameter_server_with_multi_gpu(training_instance_type, distribution
619
620
"enabled": True
620
621
}
621
622
}
622
-
623
-
624
623
"""
625
624
if training_instance_type == "local" or distribution is None :
626
625
return
@@ -645,7 +644,8 @@ def warn_if_parameter_server_with_multi_gpu(training_instance_type, distribution
645
644
def profiler_config_deprecation_warning (
646
645
profiler_config , image_uri , framework_name , framework_version
647
646
):
648
- """Put out a deprecation message for if framework profiling is specified TF >= 2.12 and PT >= 2.0"""
647
+ """Put out a deprecation message for if framework profiling is specified TF >= 2.12 and PT >=
648
+ 2.0."""
649
649
if profiler_config is None or profiler_config .framework_profile_params is None :
650
650
return
651
651
@@ -952,7 +952,7 @@ def validate_distribution(
952
952
953
953
954
954
def validate_distribution_for_instance_type (instance_type , distribution ):
955
- """Check if the provided distribution strategy is supported for the instance_type
955
+ """Check if the provided distribution strategy is supported for the instance_type.
956
956
957
957
Args:
958
958
instance_type (str): A string representing the type of training instance selected.
@@ -1071,7 +1071,7 @@ def validate_torch_distributed_distribution(
1071
1071
1072
1072
1073
1073
def _is_gpu_instance (instance_type ):
1074
- """Returns bool indicating whether instance_type supports GPU
1074
+ """Returns bool indicating whether instance_type supports GPU.
1075
1075
1076
1076
Args:
1077
1077
instance_type (str): Name of the instance_type to check against.
@@ -1090,7 +1090,7 @@ def _is_gpu_instance(instance_type):
1090
1090
1091
1091
1092
1092
def _is_trainium_instance (instance_type ):
1093
- """Returns bool indicating whether instance_type is a Trainium instance
1093
+ """Returns bool indicating whether instance_type is a Trainium instance.
1094
1094
1095
1095
Args:
1096
1096
instance_type (str): Name of the instance_type to check against.
@@ -1106,7 +1106,7 @@ def _is_trainium_instance(instance_type):
1106
1106
1107
1107
1108
1108
def python_deprecation_warning (framework , latest_supported_version ):
1109
- """Placeholder docstring"""
1109
+ """Placeholder docstring. """
1110
1110
return PYTHON_2_DEPRECATION_WARNING .format (
1111
1111
framework = framework , latest_supported_version = latest_supported_version
1112
1112
)
@@ -1120,7 +1120,6 @@ def _region_supports_debugger(region_name):
1120
1120
1121
1121
Returns:
1122
1122
bool: Whether or not the region supports Amazon SageMaker Debugger.
1123
-
1124
1123
"""
1125
1124
return region_name .lower () not in DEBUGGER_UNSUPPORTED_REGIONS
1126
1125
@@ -1133,7 +1132,6 @@ def _region_supports_profiler(region_name):
1133
1132
1134
1133
Returns:
1135
1134
bool: Whether or not the region supports Amazon SageMaker Debugger profiling feature.
1136
-
1137
1135
"""
1138
1136
return region_name .lower () not in PROFILER_UNSUPPORTED_REGIONS
1139
1137
0 commit comments