Commit 2782f8c

Revert "fix: Fix Pipeline variables related customer issues (#2959)" (#3041)

Authored by Payton Staub (staubhp) and a co-author.
This reverts commit e464689.
Co-authored-by: Payton Staub <[email protected]>

1 parent 5d8ebfc, commit 2782f8c

20 files changed: 68 additions, 430 deletions

src/sagemaker/estimator.py
Lines changed: 1 addition & 3 deletions

@@ -1879,9 +1879,7 @@ def _add_spot_checkpoint_args(cls, local_mode, estimator, train_args):
         if estimator.use_spot_instances:
             if local_mode:
                 raise ValueError("Spot training is not supported in local mode.")
-            # estimator.use_spot_instances may be a Pipeline ParameterBoolean object
-            # which is parsed during the Pipeline execution runtime
-            train_args["use_spot_instances"] = estimator.use_spot_instances
+            train_args["use_spot_instances"] = True
 
         if estimator.checkpoint_s3_uri:
             if local_mode:
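This change (and the matching reverts in src/sagemaker/session.py and src/sagemaker/workflow/airflow.py below) collapses use_spot_instances back to a literal True whenever the attribute is truthy, instead of passing the pipeline parameter object through to the request. A minimal sketch of the difference, using a hypothetical stand-in class rather than the real sagemaker.workflow.parameters.ParameterBoolean:

```python
# Minimal sketch, using a stand-in class (not the real SDK type), of why
# passing the parameter object through matters: a pipeline parameter is only
# resolved to a concrete boolean at execution time, so hard-coding True loses
# that deferral.
class FakeParameterBoolean:
    """Stand-in for sagemaker.workflow.parameters.ParameterBoolean."""

    def __init__(self, name, default_value=False):
        self.name = name
        self.default_value = default_value

    def to_request(self):
        # Pipeline variables serialize to a reference, not a literal value.
        return {"Get": f"Parameters.{self.name}"}


use_spot = FakeParameterBoolean("UseSpotInstances", default_value=False)

# Behavior restored by this revert: any truthy attribute becomes a literal.
train_args_after_revert = {"use_spot_instances": True}

# Behavior introduced by #2959 (now reverted): the parameter object was passed
# through, so the request carried a reference resolved at execution time.
train_args_from_2959 = {"use_spot_instances": use_spot.to_request()}

print(train_args_after_revert)
print(train_args_from_2959)
```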

src/sagemaker/model.py
Lines changed: 1 addition & 2 deletions

@@ -37,7 +37,6 @@
 from sagemaker.utils import unique_name_from_base
 from sagemaker.async_inference import AsyncInferenceConfig
 from sagemaker.predictor_async import AsyncPredictor
-from sagemaker.workflow.entities import PipelineVariable
 
 LOGGER = logging.getLogger("sagemaker")
 
@@ -444,7 +443,7 @@ def _upload_code(self, key_prefix: str, repack: bool = False) -> None:
         )
 
         if repack and self.model_data is not None and self.entry_point is not None:
-            if isinstance(self.model_data, PipelineVariable):
+            if isinstance(self.model_data, sagemaker.workflow.properties.Properties):
                 # model is not yet there, defer repacking to later during pipeline execution
                 return
 
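Narrowing the check back to Properties means only step outputs (for example a training step's S3ModelArtifacts) defer repacking; other pipeline variables such as a ParameterString used as model_data fall through to the normal repack path (the src/sagemaker/tensorflow/model.py revert below removes the analogous check entirely). A hedged sketch of the restored check with hypothetical inputs:

```python
from sagemaker.workflow.parameters import ParameterString
from sagemaker.workflow.properties import Properties


def defers_repacking(model_data):
    # Restored check: repacking is deferred only for step Properties,
    # i.e. outputs such as a training step's S3ModelArtifacts.
    return isinstance(model_data, Properties)


# Hypothetical model_data values: a pipeline parameter is no longer detected.
print(defers_repacking(ParameterString(name="ModelDataUrl")))   # False
print(defers_repacking("s3://my-bucket/model.tar.gz"))          # False, plain URI repacks normally
```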

src/sagemaker/processing.py
Lines changed: 1 addition & 7 deletions

@@ -36,7 +36,7 @@
 from sagemaker.session import Session
 from sagemaker.workflow.properties import Properties
 from sagemaker.workflow.parameters import Parameter
-from sagemaker.workflow.entities import Expression, PipelineVariable
+from sagemaker.workflow.entities import Expression
 from sagemaker.dataset_definition.inputs import S3Input, DatasetDefinition
 from sagemaker.apiutils._base_types import ApiObject
 from sagemaker.s3 import S3Uploader
@@ -233,12 +233,6 @@ def _normalize_args(
             kms_key (str): The ARN of the KMS key that is used to encrypt the
                 user code file (default: None).
         """
-        if code and isinstance(code, PipelineVariable):
-            raise ValueError(
-                "code argument has to be a valid S3 URI or local file path "
-                + "rather than a pipeline variable"
-            )
-
         self._current_job_name = self._generate_current_job_name(job_name=job_name)
 
         inputs_with_code = self._include_code_in_inputs(inputs, code, kms_key)
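With the upfront guard removed, code is again simply expected to be a local path or an S3 URI. A hedged usage sketch of that unchanged happy path; the role ARN, bucket, and script name are hypothetical:

```python
from sagemaker.sklearn.processing import SKLearnProcessor

# Hypothetical role ARN; any processor subclass is called the same way.
processor = SKLearnProcessor(
    framework_version="0.23-1",
    role="arn:aws:iam::111122223333:role/SageMakerRole",
    instance_type="ml.m5.xlarge",
    instance_count=1,
)

# `code` as a plain S3 URI (or a local file path); pipeline variables are
# no longer rejected up front by _normalize_args after this revert.
processor.run(code="s3://my-bucket/scripts/preprocess.py")
```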

src/sagemaker/session.py
Lines changed: 3 additions & 9 deletions

@@ -763,8 +763,6 @@ def _get_train_request(  # noqa: C901
         train_request["EnableInterContainerTrafficEncryption"] = encrypt_inter_container_traffic
 
         if use_spot_instances:
-            # estimator.use_spot_instances may be a Pipeline ParameterBoolean object
-            # which is parsed during the Pipeline execution runtime
             train_request["EnableManagedSpotTraining"] = use_spot_instances
 
         if checkpoint_s3_uri:
@@ -2342,17 +2340,13 @@ def _map_training_config(
             training_job_definition["VpcConfig"] = vpc_config
 
         if enable_network_isolation:
-            training_job_definition["EnableNetworkIsolation"] = enable_network_isolation
+            training_job_definition["EnableNetworkIsolation"] = True
 
         if encrypt_inter_container_traffic:
-            training_job_definition[
-                "EnableInterContainerTrafficEncryption"
-            ] = encrypt_inter_container_traffic
+            training_job_definition["EnableInterContainerTrafficEncryption"] = True
 
         if use_spot_instances:
-            # use_spot_instances may be a Pipeline ParameterBoolean object
-            # which is parsed during the Pipeline execution runtime
-            training_job_definition["EnableManagedSpotTraining"] = use_spot_instances
+            training_job_definition["EnableManagedSpotTraining"] = True
 
         if checkpoint_s3_uri:
             checkpoint_config = {"S3Uri": checkpoint_s3_uri}

src/sagemaker/tensorflow/model.py
Lines changed: 1 addition & 4 deletions

@@ -21,7 +21,6 @@
 from sagemaker.deprecations import removed_kwargs
 from sagemaker.predictor import Predictor
 from sagemaker.serializers import JSONSerializer
-from sagemaker.workflow.entities import PipelineVariable
 
 
 class TensorFlowPredictor(Predictor):
@@ -331,9 +330,7 @@ def prepare_container_def(self, instance_type=None, accelerator_type=None):
         image_uri = self._get_image_uri(instance_type, accelerator_type)
         env = self._get_container_env()
 
-        # If self.model_data is pipeline variable, model is not yet there.
-        # So defer repacking to later during pipeline execution
-        if self.entry_point and not isinstance(self.model_data, PipelineVariable):
+        if self.entry_point:
             key_prefix = sagemaker.fw_utils.model_code_key_prefix(
                 self.key_prefix, self.name, image_uri
             )

src/sagemaker/tuner.py
Lines changed: 15 additions & 0 deletions

@@ -39,6 +39,9 @@
     ParameterRange,
 )
 from sagemaker.workflow.entities import PipelineVariable
+from sagemaker.workflow.parameters import Parameter as PipelineParameter
+from sagemaker.workflow.functions import JsonGet as PipelineJsonGet
+from sagemaker.workflow.functions import Join as PipelineJoin
 
 from sagemaker.session import Session
 from sagemaker.utils import base_from_name, base_name_from_image, name_from_base
@@ -61,6 +64,18 @@
 logger = logging.getLogger(__name__)
 
 
+def is_pipeline_parameters(value):
+    """Determine if a value is a pipeline parameter or function representation
+
+    Args:
+        value (float or int): The value to be verified.
+
+    Returns:
+        bool: True if it is, False otherwise.
+    """
+    return isinstance(value, (PipelineParameter, PipelineJsonGet, PipelineJoin))
+
+
 class WarmStartTypes(Enum):
     """Warm Start Configuration type.
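A hedged usage sketch of the restored helper, assuming it remains importable from sagemaker.tuner as a module-level function; the parameter name is hypothetical:

```python
from sagemaker.tuner import is_pipeline_parameters
from sagemaker.workflow.parameters import ParameterInteger

max_jobs = ParameterInteger(name="MaxTuningJobs", default_value=10)

print(is_pipeline_parameters(max_jobs))  # True: a pipeline Parameter subclass
print(is_pipeline_parameters(10))        # False: a plain integer
```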

src/sagemaker/workflow/_repack_model.py
Lines changed: 2 additions & 2 deletions

@@ -39,14 +39,14 @@ def repack(inference_script, model_archive, dependencies=None, source_dir=None):
 
     Args:
         inference_script (str): The path to the custom entry point.
-        model_archive (str): The name or path (e.g. s3 uri) of the model TAR archive.
+        model_archive (str): The name of the model TAR archive.
         dependencies (str): A space-delimited string of paths to custom dependencies.
         source_dir (str): The path to a custom source directory.
     """
 
     # the data directory contains a model archive generated by a previous training job
     data_directory = "/opt/ml/input/data/training"
-    model_path = os.path.join(data_directory, model_archive.split("/")[-1])
+    model_path = os.path.join(data_directory, model_archive)
 
     # create a temporary directory
     with tempfile.TemporaryDirectory() as tmp:
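The repack script now joins the archive name onto the data directory as-is, which matches the _RepackModelStep change below that passes only the archive basename. A quick illustration of the two join behaviors with a hypothetical S3 URI:

```python
import os

data_directory = "/opt/ml/input/data/training"

# After the revert: model_archive is just the archive file name.
print(os.path.join(data_directory, "model.tar.gz"))
# /opt/ml/input/data/training/model.tar.gz

# The reverted (#2959) variant accepted a full S3 URI and took its basename:
s3_uri = "s3://my-bucket/prefix/model.tar.gz"  # hypothetical URI
print(os.path.join(data_directory, s3_uri.split("/")[-1]))
# /opt/ml/input/data/training/model.tar.gz
```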

src/sagemaker/workflow/_utils.py
Lines changed: 8 additions & 2 deletions

@@ -134,6 +134,12 @@ def __init__(
         self._model_data = model_data
         self.sagemaker_session = sagemaker_session
         self.role = role
+        if isinstance(model_data, Properties):
+            self._model_prefix = model_data
+            self._model_archive = "model.tar.gz"
+        else:
+            self._model_prefix = "/".join(self._model_data.split("/")[:-1])
+            self._model_archive = self._model_data.split("/")[-1]
         self._entry_point = entry_point
         self._entry_point_basename = os.path.basename(self._entry_point)
         self._source_dir = source_dir
@@ -155,7 +161,7 @@ def __init__(
             role=self.role,
             hyperparameters={
                 "inference_script": self._entry_point_basename,
-                "model_archive": self._model_data,
+                "model_archive": self._model_archive,
                 "dependencies": dependencies_hyperparameter,
                 "source_dir": self._source_dir,
             },
@@ -164,7 +170,7 @@
             **kwargs,
         )
         repacker.disable_profiler = True
-        inputs = TrainingInput(self._model_data)
+        inputs = TrainingInput(self._model_prefix)
 
         # super!
         super(_RepackModelStep, self).__init__(
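The restored constructor splits a plain model_data URI into a prefix (used as the repack job's training input) and an archive name (passed as the model_archive hyperparameter), while Properties inputs keep the whole property as the prefix and default the archive name to model.tar.gz. A minimal sketch of the string split with a hypothetical URI:

```python
# Hypothetical model_data URI; mirrors the restored split logic above.
model_data = "s3://my-bucket/training-output/model.tar.gz"

model_prefix = "/".join(model_data.split("/")[:-1])  # s3://my-bucket/training-output
model_archive = model_data.split("/")[-1]            # model.tar.gz

print(model_prefix)
print(model_archive)
```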

src/sagemaker/workflow/airflow.py
Lines changed: 1 addition & 3 deletions

@@ -184,9 +184,7 @@ def training_base_config(estimator, inputs=None, job_name=None, mini_batch_size=
         train_config["VpcConfig"] = job_config["vpc_config"]
 
     if estimator.use_spot_instances:
-        # estimator.use_spot_instances may be a Pipeline ParameterBoolean object
-        # which is parsed during the Pipeline execution runtime
-        train_config["EnableManagedSpotTraining"] = estimator.use_spot_instances
+        train_config["EnableManagedSpotTraining"] = True
 
     if estimator.hyperparameters() is not None:
         hyperparameters = {str(k): str(v) for (k, v) in estimator.hyperparameters().items()}

src/sagemaker/workflow/clarify_check_step.py
Lines changed: 6 additions & 4 deletions

@@ -37,8 +37,8 @@
 from sagemaker.model_monitor.model_monitoring import _MODEL_MONITOR_S3_PATH
 from sagemaker.processing import ProcessingInput, ProcessingOutput, ProcessingJob
 from sagemaker.utils import name_from_base
-from sagemaker.workflow import PipelineNonPrimitiveInputTypes
-from sagemaker.workflow.entities import RequestType, PipelineVariable
+from sagemaker.workflow import PipelineNonPrimitiveInputTypes, ExecutionVariable, Parameter
+from sagemaker.workflow.entities import RequestType, Expression
 from sagemaker.workflow.properties import Properties
 from sagemaker.workflow.steps import Step, StepTypeEnum, CacheConfig
 from sagemaker.workflow.check_job_config import CheckJobConfig
@@ -194,15 +194,17 @@ def __init__(
         )
 
         if isinstance(
-            clarify_check_config.data_config.s3_analysis_config_output_path, PipelineVariable
+            clarify_check_config.data_config.s3_analysis_config_output_path,
+            (ExecutionVariable, Expression, Parameter, Properties),
         ):
             raise RuntimeError(
                 "s3_analysis_config_output_path cannot be of type "
                 + "ExecutionVariable/Expression/Parameter/Properties"
            )
 
         if not clarify_check_config.data_config.s3_analysis_config_output_path and isinstance(
-            clarify_check_config.data_config.s3_output_path, PipelineVariable
+            clarify_check_config.data_config.s3_output_path,
+            (ExecutionVariable, Expression, Parameter, Properties),
         ):
             raise RuntimeError(
                 "`s3_output_path` cannot be of type ExecutionVariable/Expression/Parameter"

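The restored guards enumerate the concrete pipeline variable types rather than the PipelineVariable base class. A hedged sketch of what the first guard rejects, using a hypothetical parameter as the output path:

```python
from sagemaker.workflow.entities import Expression
from sagemaker.workflow.execution_variables import ExecutionVariable
from sagemaker.workflow.parameters import Parameter, ParameterString
from sagemaker.workflow.properties import Properties

# A pipeline parameter used as s3_analysis_config_output_path would trip the guard.
candidate = ParameterString(name="AnalysisConfigOutputPath")

if isinstance(candidate, (ExecutionVariable, Expression, Parameter, Properties)):
    print("rejected: s3_analysis_config_output_path cannot be a pipeline variable")
```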