|
18 | 18 | import inspect |
19 | 19 | import json |
20 | 20 | import logging |
21 | | - |
22 | 21 | from enum import Enum |
23 | | -from typing import Union, Dict, Optional, List, Set |
| 22 | +from typing import Dict, List, Optional, Set, Union |
24 | 23 |
|
25 | 24 | import sagemaker |
26 | 25 | from sagemaker.amazon.amazon_estimator import ( |
27 | | - RecordSet, |
28 | 26 | AmazonAlgorithmEstimatorBase, |
29 | 27 | FileSystemRecordSet, |
| 28 | + RecordSet, |
30 | 29 | ) |
31 | 30 | from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa |
32 | 31 | from sagemaker.analytics import HyperparameterTuningJobAnalytics |
33 | 32 | from sagemaker.deprecations import removed_function |
34 | | -from sagemaker.estimator import Framework, EstimatorBase |
35 | | -from sagemaker.inputs import TrainingInput, FileSystemInput |
| 33 | +from sagemaker.estimator import EstimatorBase, Framework |
| 34 | +from sagemaker.inputs import FileSystemInput, TrainingInput |
36 | 35 | from sagemaker.job import _Job |
37 | 36 | from sagemaker.jumpstart.utils import ( |
38 | 37 | add_jumpstart_uri_tags, |
|
44 | 43 | IntegerParameter, |
45 | 44 | ParameterRange, |
46 | 45 | ) |
47 | | -from sagemaker.workflow.entities import PipelineVariable |
48 | | -from sagemaker.workflow.pipeline_context import runnable_by_pipeline |
49 | | - |
50 | 46 | from sagemaker.session import Session |
51 | 47 | from sagemaker.utils import ( |
| 48 | + Tags, |
52 | 49 | base_from_name, |
53 | 50 | base_name_from_image, |
| 51 | + format_tags, |
54 | 52 | name_from_base, |
55 | 53 | to_string, |
56 | | - format_tags, |
57 | | - Tags, |
58 | 54 | ) |
| 55 | +from sagemaker.workflow.entities import PipelineVariable |
| 56 | +from sagemaker.workflow.pipeline_context import runnable_by_pipeline |
59 | 57 |
|
60 | 58 | AMAZON_ESTIMATOR_MODULE = "sagemaker" |
61 | 59 | AMAZON_ESTIMATOR_CLS_NAMES = { |
@@ -133,15 +131,12 @@ def __init__( |
133 | 131 |
|
134 | 132 | if warm_start_type not in list(WarmStartTypes): |
135 | 133 | raise ValueError( |
136 | | - "Invalid type: {}, valid warm start types are: {}".format( |
137 | | - warm_start_type, list(WarmStartTypes) |
138 | | - ) |
| 134 | + f"Invalid type: {warm_start_type}, " |
| 135 | + f"valid warm start types are: {list(WarmStartTypes)}" |
139 | 136 | ) |
140 | 137 |
|
141 | 138 | if not parents: |
142 | | - raise ValueError( |
143 | | - "Invalid parents: {}, parents should not be None/empty".format(parents) |
144 | | - ) |
| 139 | + raise ValueError(f"Invalid parents: {parents}, parents should not be None/empty") |
145 | 140 |
|
146 | 141 | self.type = warm_start_type |
147 | 142 | self.parents = set(parents) |
@@ -1455,9 +1450,7 @@ def _get_best_training_job(self): |
1455 | 1450 | return tuning_job_describe_result["BestTrainingJob"] |
1456 | 1451 | except KeyError: |
1457 | 1452 | raise Exception( |
1458 | | - "Best training job not available for tuning job: {}".format( |
1459 | | - self.latest_tuning_job.name |
1460 | | - ) |
| 1453 | + f"Best training job not available for tuning job: {self.latest_tuning_job.name}" |
1461 | 1454 | ) |
1462 | 1455 |
|
1463 | 1456 | def _ensure_last_tuning_job(self): |
@@ -1920,8 +1913,11 @@ def create( |
1920 | 1913 | :meth:`~sagemaker.tuner.HyperparameterTuner.fit` method launches. |
1921 | 1914 | If not specified, a default job name is generated, |
1922 | 1915 | based on the training image name and current timestamp. |
1923 | | - strategy (str): Strategy to be used for hyperparameter estimations |
1924 | | - (default: 'Bayesian'). |
| 1916 | + strategy (str or PipelineVariable): Strategy to be used for hyperparameter estimations. |
| 1917 | + More information about different strategies: |
| 1918 | + https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning-how-it-works.html. |
| 1919 | + Available options are: 'Bayesian', 'Random', 'Hyperband', |
| 1920 | + 'Grid' (default: 'Bayesian') |
1925 | 1921 | strategy_config (dict): The configuration for a training job launched by a |
1926 | 1922 | hyperparameter tuning job. |
1927 | 1923 | completion_criteria_config (dict): The configuration for tuning job completion criteria. |
@@ -2080,21 +2076,19 @@ def _validate_dict_argument(cls, name, value, allowed_keys, require_same_keys=Fa |
2080 | 2076 | return |
2081 | 2077 |
|
2082 | 2078 | if not isinstance(value, dict): |
2083 | | - raise ValueError( |
2084 | | - "Argument '{}' must be a dictionary using {} as keys".format(name, allowed_keys) |
2085 | | - ) |
| 2079 | + raise ValueError(f"Argument '{name}' must be a dictionary using {allowed_keys} as keys") |
2086 | 2080 |
|
2087 | 2081 | value_keys = sorted(value.keys()) |
2088 | 2082 |
|
2089 | 2083 | if require_same_keys: |
2090 | 2084 | if value_keys != allowed_keys: |
2091 | 2085 | raise ValueError( |
2092 | | - "The keys of argument '{}' must be the same as {}".format(name, allowed_keys) |
| 2086 | + f"The keys of argument '{name}' must be the same as {allowed_keys}" |
2093 | 2087 | ) |
2094 | 2088 | else: |
2095 | 2089 | if not set(value_keys).issubset(set(allowed_keys)): |
2096 | 2090 | raise ValueError( |
2097 | | - "The keys of argument '{}' must be a subset of {}".format(name, allowed_keys) |
| 2091 | + f"The keys of argument '{name}' must be a subset of {allowed_keys}" |
2098 | 2092 | ) |
2099 | 2093 |
|
2100 | 2094 | def _add_estimator( |
|
0 commit comments