6666 ConditionIn ,
6767 ConditionLessThanOrEqualTo ,
6868)
69- from sagemaker .workflow .condition_step import ConditionStep , JsonGet
69+ from sagemaker .workflow .condition_step import ConditionStep
7070from sagemaker .workflow .callback_step import CallbackStep , CallbackOutput , CallbackOutputTypeEnum
7171from sagemaker .workflow .lambda_step import LambdaStep , LambdaOutput , LambdaOutputTypeEnum
72- from sagemaker .workflow .properties import PropertyFile
7372from sagemaker .wrangler .processing import DataWranglerProcessor
7473from sagemaker .dataset_definition .inputs import DatasetDefinition , AthenaDatasetDefinition
7574from sagemaker .workflow .execution_variables import ExecutionVariables
76- from sagemaker .workflow .functions import Join
75+ from sagemaker .workflow .functions import Join , JsonGet
7776from sagemaker .wrangler .ingestion import generate_data_ingestion_flow_from_s3_input
7877from sagemaker .workflow .parameters import (
7978 ParameterInteger ,
8786 TuningStep ,
8887 TransformStep ,
8988 TransformInput ,
89+ PropertyFile ,
9090)
9191from sagemaker .workflow .step_collections import RegisterModel
9292from sagemaker .workflow .pipeline import Pipeline
@@ -137,7 +137,7 @@ def feature_store_session(sagemaker_session):
137137
@pytest.fixture
def pipeline_name():
    """Return a unique SageMaker pipeline name for one test run.

    The suffix is the current epoch time scaled to ~100 ns resolution, so
    concurrent test sessions are very unlikely to collide on a name.
    """
    unique_suffix = int(time.time() * 10**7)
    return f"my-pipeline-{unique_suffix}"
142142
143143@pytest .fixture
@@ -1371,6 +1371,8 @@ def test_tuning_multi_algos(
13711371 cpu_instance_type ,
13721372 pipeline_name ,
13731373 region_name ,
1374+ script_dir ,
1375+ athena_dataset_definition ,
13741376):
13751377 base_dir = os .path .join (DATA_DIR , "pytorch_mnist" )
13761378 entry_point = os .path .join (base_dir , "mnist.py" )
@@ -1382,6 +1384,42 @@ def test_tuning_multi_algos(
13821384 instance_count = ParameterInteger (name = "InstanceCount" , default_value = 1 )
13831385 instance_type = ParameterString (name = "InstanceType" , default_value = "ml.m5.xlarge" )
13841386
1387+ input_data = f"s3://sagemaker-sample-data-{ region_name } /processing/census/census-income.csv"
1388+
1389+ sklearn_processor = SKLearnProcessor (
1390+ framework_version = "0.20.0" ,
1391+ instance_type = instance_type ,
1392+ instance_count = instance_count ,
1393+ base_job_name = "test-sklearn" ,
1394+ sagemaker_session = sagemaker_session ,
1395+ role = role ,
1396+ )
1397+
1398+ property_file = PropertyFile (
1399+ name = "DataAttributes" , output_name = "attributes" , path = "attributes.json"
1400+ )
1401+
1402+ step_process = ProcessingStep (
1403+ name = "my-process" ,
1404+ display_name = "ProcessingStep" ,
1405+ description = "description for Processing step" ,
1406+ processor = sklearn_processor ,
1407+ inputs = [
1408+ ProcessingInput (source = input_data , destination = "/opt/ml/processing/input" ),
1409+ ProcessingInput (dataset_definition = athena_dataset_definition ),
1410+ ],
1411+ outputs = [
1412+ ProcessingOutput (output_name = "train_data" , source = "/opt/ml/processing/train" ),
1413+ ProcessingOutput (output_name = "attributes" , source = "/opt/ml/processing/attributes.json" ),
1414+ ],
1415+ property_files = [property_file ],
1416+ code = os .path .join (script_dir , "preprocessing.py" ),
1417+ )
1418+
1419+ static_hp_1 = ParameterString (name = "InstanceType" , default_value = "ml.m5.xlarge" )
1420+ json_get_hp = JsonGet (
1421+ step_name = step_process .name , property_file = property_file , json_path = "train_size"
1422+ )
13851423 pytorch_estimator = PyTorch (
13861424 entry_point = entry_point ,
13871425 role = role ,
@@ -1392,10 +1430,11 @@ def test_tuning_multi_algos(
13921430 sagemaker_session = sagemaker_session ,
13931431 enable_sagemaker_metrics = True ,
13941432 max_retry_attempts = 3 ,
1433+ hyperparameters = {"static-hp" : static_hp_1 , "train_size" : json_get_hp },
13951434 )
13961435
13971436 min_batch_size = ParameterString (name = "MinBatchSize" , default_value = "64" )
1398- max_batch_size = ParameterString ( name = "MaxBatchSize" , default_value = "128" )
1437+ max_batch_size = json_get_hp
13991438
14001439 tuner = HyperparameterTuner .create (
14011440 estimator_dict = {
@@ -1415,6 +1454,7 @@ def test_tuning_multi_algos(
14151454 "estimator-2" : [{"Name" : "test:acc" , "Regex" : "Overall test accuracy: (.*?);" }],
14161455 },
14171456 )
1457+
14181458 inputs = {
14191459 "estimator-1" : TrainingInput (s3_data = input_path ),
14201460 "estimator-2" : TrainingInput (s3_data = input_path ),
@@ -1429,7 +1469,7 @@ def test_tuning_multi_algos(
14291469 pipeline = Pipeline (
14301470 name = pipeline_name ,
14311471 parameters = [instance_count , instance_type , min_batch_size , max_batch_size ],
1432- steps = [step_tune ],
1472+ steps = [step_process , step_tune ],
14331473 sagemaker_session = sagemaker_session ,
14341474 )
14351475
0 commit comments