|
18 | 18 | import os |
19 | 19 | import uuid |
20 | 20 | from abc import ABCMeta, abstractmethod |
21 | | -from typing import Any, Dict |
| 21 | +from typing import Any, Dict, Union, Optional, List |
22 | 22 |
|
23 | 23 | from six import string_types, with_metaclass |
24 | 24 | from six.moves.urllib.parse import urlparse |
|
36 | 36 | TensorBoardOutputConfig, |
37 | 37 | get_default_profiler_rule, |
38 | 38 | get_rule_container_image_uri, |
| 39 | + RuleBase, |
39 | 40 | ) |
40 | 41 | from sagemaker.deprecations import removed_function, removed_kwargs, renamed_kwargs |
41 | 42 | from sagemaker.fw_utils import ( |
|
46 | 47 | tar_and_upload_dir, |
47 | 48 | validate_source_dir, |
48 | 49 | ) |
49 | | -from sagemaker.inputs import TrainingInput |
| 50 | +from sagemaker.inputs import TrainingInput, FileSystemInput |
50 | 51 | from sagemaker.job import _Job |
51 | 52 | from sagemaker.jumpstart.utils import ( |
52 | 53 | add_jumpstart_tags, |
|
75 | 76 | name_from_base, |
76 | 77 | ) |
77 | 78 | from sagemaker.workflow import is_pipeline_variable |
| 79 | +from sagemaker.workflow.entities import PipelineVariable |
78 | 80 | from sagemaker.workflow.pipeline_context import ( |
79 | 81 | PipelineSession, |
80 | 82 | runnable_by_pipeline, |
@@ -105,44 +107,44 @@ class EstimatorBase(with_metaclass(ABCMeta, object)): # pylint: disable=too-man |
105 | 107 |
|
106 | 108 | def __init__( |
107 | 109 | self, |
108 | | - role, |
109 | | - instance_count=None, |
110 | | - instance_type=None, |
111 | | - volume_size=30, |
112 | | - volume_kms_key=None, |
113 | | - max_run=24 * 60 * 60, |
114 | | - input_mode="File", |
115 | | - output_path=None, |
116 | | - output_kms_key=None, |
117 | | - base_job_name=None, |
118 | | - sagemaker_session=None, |
119 | | - tags=None, |
120 | | - subnets=None, |
121 | | - security_group_ids=None, |
122 | | - model_uri=None, |
123 | | - model_channel_name="model", |
124 | | - metric_definitions=None, |
125 | | - encrypt_inter_container_traffic=False, |
126 | | - use_spot_instances=False, |
127 | | - max_wait=None, |
128 | | - checkpoint_s3_uri=None, |
129 | | - checkpoint_local_path=None, |
130 | | - rules=None, |
131 | | - debugger_hook_config=None, |
132 | | - tensorboard_output_config=None, |
133 | | - enable_sagemaker_metrics=None, |
134 | | - enable_network_isolation=False, |
135 | | - profiler_config=None, |
136 | | - disable_profiler=False, |
137 | | - environment=None, |
138 | | - max_retry_attempts=None, |
139 | | - source_dir=None, |
140 | | - git_config=None, |
141 | | - hyperparameters=None, |
142 | | - container_log_level=logging.INFO, |
143 | | - code_location=None, |
144 | | - entry_point=None, |
145 | | - dependencies=None, |
| 110 | + role: str, |
| 111 | + instance_count: Optional[Union[int, PipelineVariable]] = None, |
| 112 | + instance_type: Optional[Union[str, PipelineVariable]] = None, |
| 113 | + volume_size: Union[int, PipelineVariable] = 30, |
| 114 | + volume_kms_key: Optional[Union[str, PipelineVariable]] = None, |
| 115 | + max_run: Union[int, PipelineVariable] = 24 * 60 * 60, |
| 116 | + input_mode: Union[str, PipelineVariable] = "File", |
| 117 | + output_path: Optional[Union[str, PipelineVariable]] = None, |
| 118 | + output_kms_key: Optional[Union[str, PipelineVariable]] = None, |
| 119 | + base_job_name: Optional[str] = None, |
| 120 | + sagemaker_session: Optional[Session] = None, |
| 121 | + tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, |
| 122 | + subnets: Optional[List[Union[str, PipelineVariable]]] = None, |
| 123 | + security_group_ids: Optional[List[Union[str, PipelineVariable]]] = None, |
| 124 | + model_uri: Optional[str] = None, |
| 125 | + model_channel_name: Union[str, PipelineVariable] = "model", |
| 126 | + metric_definitions: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, |
| 127 | + encrypt_inter_container_traffic: Union[bool, PipelineVariable] = False, |
| 128 | + use_spot_instances: Union[bool, PipelineVariable] = False, |
| 129 | + max_wait: Optional[Union[int, PipelineVariable]] = None, |
| 130 | + checkpoint_s3_uri: Optional[Union[str, PipelineVariable]] = None, |
| 131 | + checkpoint_local_path: Optional[Union[str, PipelineVariable]] = None, |
| 132 | + rules: Optional[List[RuleBase]] = None, |
| 133 | + debugger_hook_config: Optional[Union[bool, DebuggerHookConfig]] = None, |
| 134 | + tensorboard_output_config: Optional[TensorBoardOutputConfig] = None, |
| 135 | + enable_sagemaker_metrics: Optional[Union[bool, PipelineVariable]] = None, |
| 136 | + enable_network_isolation: Union[bool, PipelineVariable] = False, |
| 137 | + profiler_config: Optional[ProfilerConfig] = None, |
| 138 | + disable_profiler: bool = False, |
| 139 | + environment: Optional[Dict[str, Union[str, PipelineVariable]]] = None, |
| 140 | + max_retry_attempts: Optional[Union[int, PipelineVariable]] = None, |
| 141 | + source_dir: Optional[str] = None, |
| 142 | + git_config: Optional[Dict[str, str]] = None, |
| 143 | + hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None, |
| 144 | + container_log_level: Union[int, PipelineVariable] = logging.INFO, |
| 145 | + code_location: Optional[str] = None, |
| 146 | + entry_point: Optional[str] = None, |
| 147 | + dependencies: Optional[List[Union[str]]] = None, |
146 | 148 | **kwargs, |
147 | 149 | ): |
148 | 150 | """Initialize an ``EstimatorBase`` instance. |
@@ -922,7 +924,14 @@ def latest_job_profiler_artifacts_path(self): |
922 | 924 | return None |
923 | 925 |
|
924 | 926 | @runnable_by_pipeline |
925 | | - def fit(self, inputs=None, wait=True, logs="All", job_name=None, experiment_config=None): |
| 927 | + def fit( |
| 928 | + self, |
| 929 | + inputs: Optional[Union[str, Dict, TrainingInput, FileSystemInput]] = None, |
| 930 | + wait: bool = True, |
| 931 | + logs: str = "All", |
| 932 | + job_name: Optional[str] = None, |
| 933 | + experiment_config: Optional[Dict[str, str]] = None, |
| 934 | + ): |
926 | 935 | """Train a model using the input training dataset. |
927 | 936 |
|
928 | 937 | The API calls the Amazon SageMaker CreateTrainingJob API to start |
@@ -1870,16 +1879,22 @@ def _get_train_args(cls, estimator, inputs, experiment_config): |
1870 | 1879 | ) |
1871 | 1880 | train_args["input_mode"] = inputs.config["InputMode"] |
1872 | 1881 |
|
| 1882 | + # enable_network_isolation may be a pipeline variable place holder object |
| 1883 | + # which is parsed in execution time |
1873 | 1884 | if estimator.enable_network_isolation(): |
1874 | | - train_args["enable_network_isolation"] = True |
| 1885 | + train_args["enable_network_isolation"] = estimator.enable_network_isolation() |
1875 | 1886 |
|
1876 | 1887 | if estimator.max_retry_attempts is not None: |
1877 | 1888 | train_args["retry_strategy"] = {"MaximumRetryAttempts": estimator.max_retry_attempts} |
1878 | 1889 | else: |
1879 | 1890 | train_args["retry_strategy"] = None |
1880 | 1891 |
|
| 1892 | + # encrypt_inter_container_traffic may be a pipeline variable place holder object |
| 1893 | + # which is parsed in execution time |
1881 | 1894 | if estimator.encrypt_inter_container_traffic: |
1882 | | - train_args["encrypt_inter_container_traffic"] = True |
| 1895 | + train_args[ |
| 1896 | + "encrypt_inter_container_traffic" |
| 1897 | + ] = estimator.encrypt_inter_container_traffic |
1883 | 1898 |
|
1884 | 1899 | if isinstance(estimator, sagemaker.algorithm.AlgorithmEstimator): |
1885 | 1900 | train_args["algorithm_arn"] = estimator.algorithm_arn |
@@ -2025,45 +2040,45 @@ class Estimator(EstimatorBase): |
2025 | 2040 |
|
2026 | 2041 | def __init__( |
2027 | 2042 | self, |
2028 | | - image_uri, |
2029 | | - role, |
2030 | | - instance_count=None, |
2031 | | - instance_type=None, |
2032 | | - volume_size=30, |
2033 | | - volume_kms_key=None, |
2034 | | - max_run=24 * 60 * 60, |
2035 | | - input_mode="File", |
2036 | | - output_path=None, |
2037 | | - output_kms_key=None, |
2038 | | - base_job_name=None, |
2039 | | - sagemaker_session=None, |
2040 | | - hyperparameters=None, |
2041 | | - tags=None, |
2042 | | - subnets=None, |
2043 | | - security_group_ids=None, |
2044 | | - model_uri=None, |
2045 | | - model_channel_name="model", |
2046 | | - metric_definitions=None, |
2047 | | - encrypt_inter_container_traffic=False, |
2048 | | - use_spot_instances=False, |
2049 | | - max_wait=None, |
2050 | | - checkpoint_s3_uri=None, |
2051 | | - checkpoint_local_path=None, |
2052 | | - enable_network_isolation=False, |
2053 | | - rules=None, |
2054 | | - debugger_hook_config=None, |
2055 | | - tensorboard_output_config=None, |
2056 | | - enable_sagemaker_metrics=None, |
2057 | | - profiler_config=None, |
2058 | | - disable_profiler=False, |
2059 | | - environment=None, |
2060 | | - max_retry_attempts=None, |
2061 | | - source_dir=None, |
2062 | | - git_config=None, |
2063 | | - container_log_level=logging.INFO, |
2064 | | - code_location=None, |
2065 | | - entry_point=None, |
2066 | | - dependencies=None, |
| 2043 | + image_uri: Union[str, PipelineVariable], |
| 2044 | + role: str, |
| 2045 | + instance_count: Optional[Union[int, PipelineVariable]] = None, |
| 2046 | + instance_type: Optional[Union[str, PipelineVariable]] = None, |
| 2047 | + volume_size: Union[int, PipelineVariable] = 30, |
| 2048 | + volume_kms_key: Optional[Union[str, PipelineVariable]] = None, |
| 2049 | + max_run: Union[int, PipelineVariable] = 24 * 60 * 60, |
| 2050 | + input_mode: Union[str, PipelineVariable] = "File", |
| 2051 | + output_path: Optional[Union[str, PipelineVariable]] = None, |
| 2052 | + output_kms_key: Optional[Union[str, PipelineVariable]] = None, |
| 2053 | + base_job_name: Optional[str] = None, |
| 2054 | + sagemaker_session: Optional[Session] = None, |
| 2055 | + hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None, |
| 2056 | + tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, |
| 2057 | + subnets: Optional[List[Union[str, PipelineVariable]]] = None, |
| 2058 | + security_group_ids: Optional[List[Union[str, PipelineVariable]]] = None, |
| 2059 | + model_uri: Optional[str] = None, |
| 2060 | + model_channel_name: Union[str, PipelineVariable] = "model", |
| 2061 | + metric_definitions: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, |
| 2062 | + encrypt_inter_container_traffic: Union[bool, PipelineVariable] = False, |
| 2063 | + use_spot_instances: Union[bool, PipelineVariable] = False, |
| 2064 | + max_wait: Optional[Union[int, PipelineVariable]] = None, |
| 2065 | + checkpoint_s3_uri: Optional[Union[str, PipelineVariable]] = None, |
| 2066 | + checkpoint_local_path: Optional[Union[str, PipelineVariable]] = None, |
| 2067 | + enable_network_isolation: Union[bool, PipelineVariable] = False, |
| 2068 | + rules: Optional[List[RuleBase]] = None, |
| 2069 | + debugger_hook_config: Optional[Union[DebuggerHookConfig, bool]] = None, |
| 2070 | + tensorboard_output_config: Optional[TensorBoardOutputConfig] = None, |
| 2071 | + enable_sagemaker_metrics: Optional[Union[bool, PipelineVariable]] = None, |
| 2072 | + profiler_config: Optional[ProfilerConfig] = None, |
| 2073 | + disable_profiler: bool = False, |
| 2074 | + environment: Optional[Dict[str, Union[str, PipelineVariable]]] = None, |
| 2075 | + max_retry_attempts: Optional[Union[int, PipelineVariable]] = None, |
| 2076 | + source_dir: Optional[str] = None, |
| 2077 | + git_config: Optional[Dict[str, str]] = None, |
| 2078 | + container_log_level: Union[int, PipelineVariable] = logging.INFO, |
| 2079 | + code_location: Optional[str] = None, |
| 2080 | + entry_point: Optional[str] = None, |
| 2081 | + dependencies: Optional[List[str]] = None, |
2067 | 2082 | **kwargs, |
2068 | 2083 | ): |
2069 | 2084 | """Initialize an ``Estimator`` instance. |
@@ -2488,18 +2503,18 @@ class Framework(EstimatorBase): |
2488 | 2503 |
|
2489 | 2504 | def __init__( |
2490 | 2505 | self, |
2491 | | - entry_point, |
2492 | | - source_dir=None, |
2493 | | - hyperparameters=None, |
2494 | | - container_log_level=logging.INFO, |
2495 | | - code_location=None, |
2496 | | - image_uri=None, |
2497 | | - dependencies=None, |
2498 | | - enable_network_isolation=False, |
2499 | | - git_config=None, |
2500 | | - checkpoint_s3_uri=None, |
2501 | | - checkpoint_local_path=None, |
2502 | | - enable_sagemaker_metrics=None, |
| 2506 | + entry_point: str, |
| 2507 | + source_dir: Optional[str] = None, |
| 2508 | + hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None, |
| 2509 | + container_log_level: Union[int, PipelineVariable] = logging.INFO, |
| 2510 | + code_location: Optional[str] = None, |
| 2511 | + image_uri: Optional[Union[str, PipelineVariable]] = None, |
| 2512 | + dependencies: Optional[List[str]] = None, |
| 2513 | + enable_network_isolation: Union[bool, PipelineVariable] = False, |
| 2514 | + git_config: Optional[Dict[str, str]] = None, |
| 2515 | + checkpoint_s3_uri: Optional[Union[str, PipelineVariable]] = None, |
| 2516 | + checkpoint_local_path: Optional[Union[str, PipelineVariable]] = None, |
| 2517 | + enable_sagemaker_metrics: Optional[Union[bool, PipelineVariable]] = None, |
2503 | 2518 | **kwargs, |
2504 | 2519 | ): |
2505 | 2520 | """Base class initializer. |
|
0 commit comments