1111# ANY KIND, either express or implied. See the License for the specific
1212# language governing permissions and limitations under the License.
1313from __future__ import absolute_import
14+ import json
1415
1516import pytest
1617from mock import Mock , MagicMock
17- from sagemaker .workflow .conditions import ConditionEquals
18- from sagemaker .workflow .parameters import ParameterInteger
18+ from sagemaker .workflow .conditions import (
19+ ConditionEquals ,
20+ ConditionGreaterThan ,
21+ ConditionGreaterThanOrEqualTo ,
22+ ConditionIn ,
23+ ConditionLessThan ,
24+ ConditionLessThanOrEqualTo ,
25+ ConditionNot ,
26+ ConditionOr ,
27+ )
28+ from sagemaker .workflow .execution_variables import ExecutionVariables
29+ from sagemaker .workflow .parameters import ParameterInteger , ParameterString
1930from sagemaker .workflow .condition_step import ConditionStep
2031from sagemaker .workflow .pipeline import Pipeline , PipelineGraph
32+ from sagemaker .workflow .properties import Properties
2133from tests .unit .sagemaker .workflow .helpers import CustomStep , ordered
2234
2335
@@ -56,7 +68,7 @@ def test_condition_step():
5668 "Conditions" : [
5769 {
5870 "Type" : "Equals" ,
59- "LeftValue" : { "Get" : "Parameters.MyInt" } ,
71+ "LeftValue" : param ,
6072 "RightValue" : 1 ,
6173 },
6274 ],
@@ -79,6 +91,147 @@ def test_condition_step():
7991 assert cond_step .properties .Outcome .expr == {"Get" : "Steps.MyConditionStep.Outcome" }
8092
8193
94+ def test_pipeline_condition_step_interpolated (sagemaker_session ):
95+ param1 = ParameterInteger (name = "MyInt1" )
96+ param2 = ParameterInteger (name = "MyInt2" )
97+ param3 = ParameterString (name = "MyStr" )
98+ var = ExecutionVariables .START_DATETIME
99+ prop = Properties ("foo" )
100+
101+ cond_eq = ConditionEquals (left = param1 , right = param2 )
102+ cond_gt = ConditionGreaterThan (left = var , right = "2020-12-01" )
103+ cond_gte = ConditionGreaterThanOrEqualTo (left = var , right = param3 )
104+ cond_lt = ConditionLessThan (left = var , right = "2020-12-01" )
105+ cond_lte = ConditionLessThanOrEqualTo (left = var , right = param3 )
106+ cond_in = ConditionIn (value = param3 , in_values = ["abc" , "def" ])
107+ cond_in_mixed = ConditionIn (value = param3 , in_values = ["abc" , prop , var ])
108+ cond_not_eq = ConditionNot (expression = cond_eq )
109+ cond_not_in = ConditionNot (expression = cond_in )
110+ cond_or = ConditionOr (conditions = [cond_gt , cond_in ])
111+
112+ step1 = CustomStep (name = "MyStep1" )
113+ step2 = CustomStep (name = "MyStep2" )
114+ cond_step = ConditionStep (
115+ name = "MyConditionStep" ,
116+ conditions = [
117+ cond_eq ,
118+ cond_gt ,
119+ cond_gte ,
120+ cond_lt ,
121+ cond_lte ,
122+ cond_in ,
123+ cond_in_mixed ,
124+ cond_not_eq ,
125+ cond_not_in ,
126+ cond_or ,
127+ ],
128+ if_steps = [step1 ],
129+ else_steps = [step2 ],
130+ )
131+
132+ pipeline = Pipeline (
133+ name = "MyPipeline" ,
134+ parameters = [param1 , param2 , param3 ],
135+ steps = [cond_step ],
136+ sagemaker_session = sagemaker_session ,
137+ )
138+ assert json .loads (pipeline .definition ()) == {
139+ "Version" : "2020-12-01" ,
140+ "Metadata" : {},
141+ "Parameters" : [
142+ {"Name" : "MyInt1" , "Type" : "Integer" },
143+ {"Name" : "MyInt2" , "Type" : "Integer" },
144+ {"Name" : "MyStr" , "Type" : "String" },
145+ ],
146+ "PipelineExperimentConfig" : {
147+ "ExperimentName" : {"Get" : "Execution.PipelineName" },
148+ "TrialName" : {"Get" : "Execution.PipelineExecutionId" },
149+ },
150+ "Steps" : [
151+ {
152+ "Name" : "MyConditionStep" ,
153+ "Type" : "Condition" ,
154+ "Arguments" : {
155+ "Conditions" : [
156+ {
157+ "Type" : "Equals" ,
158+ "LeftValue" : {"Get" : "Parameters.MyInt1" },
159+ "RightValue" : {"Get" : "Parameters.MyInt2" },
160+ },
161+ {
162+ "Type" : "GreaterThan" ,
163+ "LeftValue" : {"Get" : "Execution.StartDateTime" },
164+ "RightValue" : "2020-12-01" ,
165+ },
166+ {
167+ "Type" : "GreaterThanOrEqualTo" ,
168+ "LeftValue" : {"Get" : "Execution.StartDateTime" },
169+ "RightValue" : {"Get" : "Parameters.MyStr" },
170+ },
171+ {
172+ "Type" : "LessThan" ,
173+ "LeftValue" : {"Get" : "Execution.StartDateTime" },
174+ "RightValue" : "2020-12-01" ,
175+ },
176+ {
177+ "Type" : "LessThanOrEqualTo" ,
178+ "LeftValue" : {"Get" : "Execution.StartDateTime" },
179+ "RightValue" : {"Get" : "Parameters.MyStr" },
180+ },
181+ {
182+ "Type" : "In" ,
183+ "QueryValue" : {"Get" : "Parameters.MyStr" },
184+ "Values" : ["abc" , "def" ],
185+ },
186+ {
187+ "Type" : "In" ,
188+ "QueryValue" : {"Get" : "Parameters.MyStr" },
189+ "Values" : [
190+ "abc" ,
191+ {"Get" : "Steps.foo" },
192+ {"Get" : "Execution.StartDateTime" },
193+ ],
194+ },
195+ {
196+ "Type" : "Not" ,
197+ "Expression" : {
198+ "Type" : "Equals" ,
199+ "LeftValue" : {"Get" : "Parameters.MyInt1" },
200+ "RightValue" : {"Get" : "Parameters.MyInt2" },
201+ },
202+ },
203+ {
204+ "Type" : "Not" ,
205+ "Expression" : {
206+ "Type" : "In" ,
207+ "QueryValue" : {"Get" : "Parameters.MyStr" },
208+ "Values" : ["abc" , "def" ],
209+ },
210+ },
211+ {
212+ "Type" : "Or" ,
213+ "Conditions" : [
214+ {
215+ "Type" : "GreaterThan" ,
216+ "LeftValue" : {"Get" : "Execution.StartDateTime" },
217+ "RightValue" : "2020-12-01" ,
218+ },
219+ {
220+ "Type" : "In" ,
221+ "QueryValue" : {"Get" : "Parameters.MyStr" },
222+ "Values" : ["abc" , "def" ],
223+ },
224+ ],
225+ },
226+ ],
227+ "IfSteps" : [{"Name" : "MyStep1" , "Type" : "Training" , "Arguments" : {}}],
228+ "ElseSteps" : [{"Name" : "MyStep2" , "Type" : "Training" , "Arguments" : {}}],
229+ },
230+ }
231+ ],
232+ }
233+
234+
82235def test_pipeline (sagemaker_session ):
83236 param = ParameterInteger (name = "MyInt" , default_value = 2 )
84237 cond = ConditionEquals (left = param , right = 1 )
0 commit comments