diff --git a/src/sagemaker/tuner.py b/src/sagemaker/tuner.py index 4b0f38f36f..fa8f9b8555 100644 --- a/src/sagemaker/tuner.py +++ b/src/sagemaker/tuner.py @@ -18,21 +18,20 @@ import inspect import json import logging - from enum import Enum -from typing import Union, Dict, Optional, List, Set +from typing import Dict, List, Optional, Set, Union import sagemaker from sagemaker.amazon.amazon_estimator import ( - RecordSet, AmazonAlgorithmEstimatorBase, FileSystemRecordSet, + RecordSet, ) from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa from sagemaker.analytics import HyperparameterTuningJobAnalytics from sagemaker.deprecations import removed_function -from sagemaker.estimator import Framework, EstimatorBase -from sagemaker.inputs import TrainingInput, FileSystemInput +from sagemaker.estimator import EstimatorBase, Framework +from sagemaker.inputs import FileSystemInput, TrainingInput from sagemaker.job import _Job from sagemaker.jumpstart.utils import ( add_jumpstart_uri_tags, @@ -44,18 +43,17 @@ IntegerParameter, ParameterRange, ) -from sagemaker.workflow.entities import PipelineVariable -from sagemaker.workflow.pipeline_context import runnable_by_pipeline - from sagemaker.session import Session from sagemaker.utils import ( + Tags, base_from_name, base_name_from_image, + format_tags, name_from_base, to_string, - format_tags, - Tags, ) +from sagemaker.workflow.entities import PipelineVariable +from sagemaker.workflow.pipeline_context import runnable_by_pipeline AMAZON_ESTIMATOR_MODULE = "sagemaker" AMAZON_ESTIMATOR_CLS_NAMES = { @@ -133,15 +131,12 @@ def __init__( if warm_start_type not in list(WarmStartTypes): raise ValueError( - "Invalid type: {}, valid warm start types are: {}".format( - warm_start_type, list(WarmStartTypes) - ) + f"Invalid type: {warm_start_type}, " + f"valid warm start types are: {list(WarmStartTypes)}" ) if not parents: - raise ValueError( - "Invalid parents: {}, parents should not be None/empty".format(parents) - ) + raise ValueError(f"Invalid parents: {parents}, parents should not be None/empty") self.type = warm_start_type self.parents = set(parents) @@ -1455,9 +1450,7 @@ def _get_best_training_job(self): return tuning_job_describe_result["BestTrainingJob"] except KeyError: raise Exception( - "Best training job not available for tuning job: {}".format( - self.latest_tuning_job.name - ) + f"Best training job not available for tuning job: {self.latest_tuning_job.name}" ) def _ensure_last_tuning_job(self): @@ -1920,8 +1913,11 @@ def create( :meth:`~sagemaker.tuner.HyperparameterTuner.fit` method launches. If not specified, a default job name is generated, based on the training image name and current timestamp. - strategy (str): Strategy to be used for hyperparameter estimations - (default: 'Bayesian'). + strategy (str or PipelineVariable): Strategy to be used for hyperparameter estimations. + More information about different strategies: + https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning-how-it-works.html. + Available options are: 'Bayesian', 'Random', 'Hyperband', + 'Grid' (default: 'Bayesian') strategy_config (dict): The configuration for a training job launched by a hyperparameter tuning job. completion_criteria_config (dict): The configuration for tuning job completion criteria. @@ -2080,21 +2076,19 @@ def _validate_dict_argument(cls, name, value, allowed_keys, require_same_keys=Fa return if not isinstance(value, dict): - raise ValueError( - "Argument '{}' must be a dictionary using {} as keys".format(name, allowed_keys) - ) + raise ValueError(f"Argument '{name}' must be a dictionary using {allowed_keys} as keys") value_keys = sorted(value.keys()) if require_same_keys: if value_keys != allowed_keys: raise ValueError( - "The keys of argument '{}' must be the same as {}".format(name, allowed_keys) + f"The keys of argument '{name}' must be the same as {allowed_keys}" ) else: if not set(value_keys).issubset(set(allowed_keys)): raise ValueError( - "The keys of argument '{}' must be a subset of {}".format(name, allowed_keys) + f"The keys of argument '{name}' must be a subset of {allowed_keys}" ) def _add_estimator(