1616import json
1717
1818from copy import deepcopy
19- from typing import Any , Dict , List , Sequence , Union
19+ from typing import Any , Dict , List , Sequence , Union , Optional
2020
2121import attr
2222import botocore
3030 Expression ,
3131 RequestType ,
3232)
33+ from sagemaker .workflow .execution_variables import ExecutionVariables
3334from sagemaker .workflow .parameters import Parameter
35+ from sagemaker .workflow .pipeline_experiment_config import PipelineExperimentConfig
3436from sagemaker .workflow .properties import Properties
3537from sagemaker .workflow .steps import Step
3638from sagemaker .workflow .step_collections import StepCollection
@@ -44,6 +46,12 @@ class Pipeline(Entity):
4446 Attributes:
4547 name (str): The name of the pipeline.
4648 parameters (Sequence[Parameters]): The list of the parameters.
49+ pipeline_experiment_config (Optional[PipelineExperimentConfig]): If set,
50+ the workflow will attempt to create an experiment and trial before
51+ executing the steps. Creation will be skipped if an experiment or a trial with
52+ the same name already exists. By default, pipeline name is used as
53+ experiment name and execution id is used as the trial name.
54+ If set to None, no experiment or trial will be created automatically.
4755 steps (Sequence[Steps]): The list of the non-conditional steps associated with the pipeline.
4856 Any steps that are within the
4957 `if_steps` or `else_steps` of a `ConditionStep` cannot be listed in the steps of a
@@ -57,6 +65,11 @@ class Pipeline(Entity):
5765
5866 name : str = attr .ib (factory = str )
5967 parameters : Sequence [Parameter ] = attr .ib (factory = list )
68+ pipeline_experiment_config : Optional [PipelineExperimentConfig ] = attr .ib (
69+ default = PipelineExperimentConfig (
70+ ExecutionVariables .PIPELINE_NAME , ExecutionVariables .PIPELINE_EXECUTION_ID
71+ )
72+ )
6073 steps : Sequence [Union [Step , StepCollection ]] = attr .ib (factory = list )
6174 sagemaker_session : Session = attr .ib (factory = Session )
6275
@@ -69,22 +82,23 @@ def to_request(self) -> RequestType:
6982 "Version" : self ._version ,
7083 "Metadata" : self ._metadata ,
7184 "Parameters" : list_to_request (self .parameters ),
85+ "PipelineExperimentConfig" : self .pipeline_experiment_config .to_request ()
86+ if self .pipeline_experiment_config is not None
87+ else None ,
7288 "Steps" : list_to_request (self .steps ),
7389 }
7490
7591 def create (
7692 self ,
7793 role_arn : str ,
7894 description : str = None ,
79- experiment_name : str = None ,
8095 tags : List [Dict [str , str ]] = None ,
8196 ) -> Dict [str , Any ]:
8297 """Creates a Pipeline in the Pipelines service.
8398
8499 Args:
85100 role_arn (str): The role arn that is assumed by the pipeline to create step artifacts.
86101 description (str): A description of the pipeline.
87- experiment_name (str): The name of the experiment.
88102 tags (List[Dict[str, str]]): A list of {"Key": "string", "Value": "string"} dicts as
89103 tags.
90104
@@ -96,7 +110,6 @@ def create(
96110 kwargs = self ._create_args (role_arn , description )
97111 update_args (
98112 kwargs ,
99- ExperimentName = experiment_name ,
100113 Tags = tags ,
101114 )
102115 return self .sagemaker_session .sagemaker_client .create_pipeline (** kwargs )
@@ -106,7 +119,7 @@ def _create_args(self, role_arn: str, description: str):
106119
107120 Args:
108121 role_arn (str): The role arn that is assumed by pipelines to create step artifacts.
109- pipeline_description (str): A description of the pipeline.
122+ description (str): A description of the pipeline.
110123
111124 Returns:
112125 A keyword argument dict for calling create_pipeline.
@@ -147,23 +160,21 @@ def upsert(
147160 self ,
148161 role_arn : str ,
149162 description : str = None ,
150- experiment_name : str = None ,
151163 tags : List [Dict [str , str ]] = None ,
152164 ) -> Dict [str , Any ]:
153165 """Creates a pipeline or updates it, if it already exists.
154166
155167 Args:
156168 role_arn (str): The role arn that is assumed by workflow to create step artifacts.
157- pipeline_description (str): A description of the pipeline.
158- experiment_name (str): The name of the experiment.
169+ description (str): A description of the pipeline.
159170 tags (List[Dict[str, str]]): A list of {"Key": "string", "Value": "string"} dicts as
160171 tags.
161172
162173 Returns:
163174 response dict from service
164175 """
165176 try :
166- response = self .create (role_arn , description , experiment_name , tags )
177+ response = self .create (role_arn , description , tags )
167178 except ClientError as e :
168179 error = e .response ["Error" ]
169180 if (
@@ -224,6 +235,9 @@ def start(
224235 def definition (self ) -> str :
225236 """Converts a request structure to string representation for workflow service calls."""
226237 request_dict = self .to_request ()
238+ request_dict ["PipelineExperimentConfig" ] = interpolate (
239+ request_dict ["PipelineExperimentConfig" ]
240+ )
227241 request_dict ["Steps" ] = interpolate (request_dict ["Steps" ])
228242
229243 return json .dumps (request_dict )
0 commit comments