Skip to content

Commit 96e9f1c

Browse files
committed
More formatting
1 parent e4baf4d commit 96e9f1c

File tree

3 files changed

+34
-40
lines changed

3 files changed

+34
-40
lines changed

src/sagemaker/fw_utils.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13-
"""Utility methods used by framework classes"""
13+
"""Utility methods used by framework classes."""
1414
from __future__ import absolute_import
1515

1616
import json
@@ -40,6 +40,7 @@
4040

4141
UploadedCode = namedtuple("UploadedCode", ["s3_prefix", "script_name"])
4242
"""sagemaker.fw_utils.UploadedCode: An object containing the S3 prefix and script name.
43+
4344
This is for the source code used for the entry point with an ``Estimator``. It can be
4445
instantiated with positional or keyword arguments.
4546
"""
@@ -210,7 +211,7 @@ def validate_source_code_input_against_pipeline_variables(
210211
git_config: Optional[Dict[str, str]] = None,
211212
enable_network_isolation: Union[bool, PipelineVariable] = False,
212213
):
213-
"""Validate source code input against pipeline variables
214+
"""Validate source code input against pipeline variables.
214215
215216
Args:
216217
entry_point (str or PipelineVariable): The path to the local Python source file that
@@ -480,7 +481,7 @@ def tar_and_upload_dir(
480481

481482

482483
def _list_files_to_compress(script, directory):
483-
"""Placeholder docstring"""
484+
"""Placeholder docstring."""
484485
if directory is None:
485486
return [script]
486487

@@ -619,8 +620,6 @@ def warn_if_parameter_server_with_multi_gpu(training_instance_type, distribution
619620
"enabled": True
620621
}
621622
}
622-
623-
624623
"""
625624
if training_instance_type == "local" or distribution is None:
626625
return
@@ -645,7 +644,8 @@ def warn_if_parameter_server_with_multi_gpu(training_instance_type, distribution
645644
def profiler_config_deprecation_warning(
646645
profiler_config, image_uri, framework_name, framework_version
647646
):
648-
"""Put out a deprecation message for if framework profiling is specified TF >= 2.12 and PT >= 2.0"""
647+
"""Put out a deprecation message for if framework profiling is specified TF >= 2.12 and PT >=
648+
2.0."""
649649
if profiler_config is None or profiler_config.framework_profile_params is None:
650650
return
651651

@@ -952,7 +952,7 @@ def validate_distribution(
952952

953953

954954
def validate_distribution_for_instance_type(instance_type, distribution):
955-
"""Check if the provided distribution strategy is supported for the instance_type
955+
"""Check if the provided distribution strategy is supported for the instance_type.
956956
957957
Args:
958958
instance_type (str): A string representing the type of training instance selected.
@@ -1071,7 +1071,7 @@ def validate_torch_distributed_distribution(
10711071

10721072

10731073
def _is_gpu_instance(instance_type):
1074-
"""Returns bool indicating whether instance_type supports GPU
1074+
"""Returns bool indicating whether instance_type supports GPU.
10751075
10761076
Args:
10771077
instance_type (str): Name of the instance_type to check against.
@@ -1090,7 +1090,7 @@ def _is_gpu_instance(instance_type):
10901090

10911091

10921092
def _is_trainium_instance(instance_type):
1093-
"""Returns bool indicating whether instance_type is a Trainium instance
1093+
"""Returns bool indicating whether instance_type is a Trainium instance.
10941094
10951095
Args:
10961096
instance_type (str): Name of the instance_type to check against.
@@ -1106,7 +1106,7 @@ def _is_trainium_instance(instance_type):
11061106

11071107

11081108
def python_deprecation_warning(framework, latest_supported_version):
1109-
"""Placeholder docstring"""
1109+
"""Placeholder docstring."""
11101110
return PYTHON_2_DEPRECATION_WARNING.format(
11111111
framework=framework, latest_supported_version=latest_supported_version
11121112
)
@@ -1120,7 +1120,6 @@ def _region_supports_debugger(region_name):
11201120
11211121
Returns:
11221122
bool: Whether or not the region supports Amazon SageMaker Debugger.
1123-
11241123
"""
11251124
return region_name.lower() not in DEBUGGER_UNSUPPORTED_REGIONS
11261125

@@ -1133,7 +1132,6 @@ def _region_supports_profiler(region_name):
11331132
11341133
Returns:
11351134
bool: Whether or not the region supports Amazon SageMaker Debugger profiling feature.
1136-
11371135
"""
11381136
return region_name.lower() not in PROFILER_UNSUPPORTED_REGIONS
11391137

src/sagemaker/huggingface/estimator.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,13 @@
1515

1616
import logging
1717
import re
18-
from typing import Optional, Union, Dict
18+
from typing import Dict, Optional, Union
1919

20-
from sagemaker.estimator import Framework, EstimatorBase
21-
from sagemaker.fw_utils import (
22-
framework_name_from_image,
23-
validate_distribution,
24-
)
20+
from sagemaker.estimator import EstimatorBase, Framework
21+
from sagemaker.fw_utils import framework_name_from_image, validate_distribution
2522
from sagemaker.huggingface.model import HuggingFaceModel
26-
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
27-
2823
from sagemaker.huggingface.training_compiler.config import TrainingCompilerConfig
24+
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
2925
from sagemaker.workflow.entities import PipelineVariable
3026

3127
logger = logging.getLogger("sagemaker")
@@ -66,7 +62,7 @@ def __init__(
6662
Args:
6763
py_version (str): Python version you want to use for executing your model training
6864
code. Defaults to ``None``. Required unless ``image_uri`` is provided. If
69-
using PyTorch, the current supported version is ``py36``. If using TensorFlow,
65+
using PyTorch, the current supported version is ``py39``. If using TensorFlow,
7066
the current supported version is ``py37``.
7167
entry_point (str or PipelineVariable): Path (absolute or relative) to the Python source
7268
file which should be executed as the entry point to training.

src/sagemaker/processing.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -18,51 +18,51 @@
1818
"""
1919
from __future__ import absolute_import
2020

21+
import logging
2122
import os
2223
import pathlib
23-
import logging
24+
import re
25+
from copy import copy
2426
from textwrap import dedent
2527
from typing import Dict, List, Optional, Union
26-
from copy import copy
27-
import re
2828

2929
import attr
30-
3130
from six.moves.urllib.parse import urlparse
3231
from six.moves.urllib.request import url2pathname
32+
3333
from sagemaker import s3
34+
from sagemaker.apiutils._base_types import ApiObject
3435
from sagemaker.config import (
36+
PROCESSING_JOB_ENABLE_NETWORK_ISOLATION_PATH,
37+
PROCESSING_JOB_ENVIRONMENT_PATH,
38+
PROCESSING_JOB_INTER_CONTAINER_ENCRYPTION_PATH,
3539
PROCESSING_JOB_KMS_KEY_ID_PATH,
40+
PROCESSING_JOB_ROLE_ARN_PATH,
3641
PROCESSING_JOB_SECURITY_GROUP_IDS_PATH,
3742
PROCESSING_JOB_SUBNETS_PATH,
38-
PROCESSING_JOB_ENABLE_NETWORK_ISOLATION_PATH,
3943
PROCESSING_JOB_VOLUME_KMS_KEY_ID_PATH,
40-
PROCESSING_JOB_ROLE_ARN_PATH,
41-
PROCESSING_JOB_INTER_CONTAINER_ENCRYPTION_PATH,
42-
PROCESSING_JOB_ENVIRONMENT_PATH,
4344
)
45+
from sagemaker.dataset_definition.inputs import DatasetDefinition, S3Input
4446
from sagemaker.job import _Job
4547
from sagemaker.local import LocalSession
4648
from sagemaker.network import NetworkConfig
49+
from sagemaker.s3 import S3Uploader
50+
from sagemaker.session import Session
4751
from sagemaker.utils import (
52+
Tags,
4853
base_name_from_image,
54+
check_and_get_run_experiment_config,
55+
format_tags,
4956
get_config_value,
5057
name_from_base,
51-
check_and_get_run_experiment_config,
52-
resolve_value_from_config,
5358
resolve_class_attribute_from_config,
54-
Tags,
55-
format_tags,
59+
resolve_value_from_config,
5660
)
57-
from sagemaker.session import Session
5861
from sagemaker.workflow import is_pipeline_variable
62+
from sagemaker.workflow.entities import PipelineVariable
63+
from sagemaker.workflow.execution_variables import ExecutionVariables
5964
from sagemaker.workflow.functions import Join
6065
from sagemaker.workflow.pipeline_context import runnable_by_pipeline
61-
from sagemaker.workflow.execution_variables import ExecutionVariables
62-
from sagemaker.workflow.entities import PipelineVariable
63-
from sagemaker.dataset_definition.inputs import S3Input, DatasetDefinition
64-
from sagemaker.apiutils._base_types import ApiObject
65-
from sagemaker.s3 import S3Uploader
6666

6767
logger = logging.getLogger(__name__)
6868

@@ -1465,7 +1465,7 @@ def __init__(
14651465
instance_type (str or PipelineVariable): The type of EC2 instance to use for
14661466
processing, for example, 'ml.c4.xlarge'.
14671467
py_version (str): Python version you want to use for executing your
1468-
model training code. One of 'py2' or 'py3'. Defaults to 'py3'. Value
1468+
model training code. Ex `py38, py39, py310, py311`. Value
14691469
is ignored when ``image_uri`` is provided.
14701470
image_uri (str or PipelineVariable): The URI of the Docker image to use for the
14711471
processing jobs (default: None).

0 commit comments

Comments
 (0)