     ScriptProcessor,
 )
 from sagemaker.sklearn.processing import SKLearnProcessor
-from sagemaker.workflow.conditions import ConditionLessThanOrEqualTo
+from sagemaker.workflow.conditions import ConditionGreaterThanOrEqualTo
 from sagemaker.workflow.condition_step import (
     ConditionStep,
 )
 from sagemaker.workflow.step_collections import RegisterModel
 
 from botocore.exceptions import ClientError
-
+import boto3
 
 BASE_DIR = os.path.dirname(os.path.realpath(__file__))
 
@@ -76,72 +76,7 @@ def get_session(region, default_bucket):
         default_bucket=default_bucket,
     )
 
-def resolve_ecr_uri_from_image_versions(sagemaker_session, image_versions, image_name):
-    """ Gets ECR URI from image versions
-    Args:
-        sagemaker_session: boto3 session for sagemaker client
-        image_versions: list of the image versions
-        image_name: Name of the image
-
-    Returns:
-        ECR URI of the image version
-    """
 
-    # Fetch image details to get the Base Image URI
-    for image_version in image_versions:
-        if image_version['ImageVersionStatus'] == 'CREATED':
-            image_arn = image_version["ImageVersionArn"]
-            version = image_version["Version"]
-            logger.info(f"Identified the latest image version: {image_arn}")
-            response = sagemaker_session.sagemaker_client.describe_image_version(
-                ImageName=image_name,
-                Version=version
-            )
-            return response['ContainerImage']
-    return None
-
-def resolve_ecr_uri(sagemaker_session, image_arn):
-    """Gets the ECR URI from the image name
-
-    Args:
-        sagemaker_session: boto3 session for sagemaker client
-        image_name: name of the image
-
-    Returns:
-        ECR URI of the latest image version
-    """
-
-    # Fetching image name from image_arn (^arn:aws(-[\w]+)*:sagemaker:.+:[0-9]{12}:image/[a-z0-9]([-.]?[a-z0-9])*$)
-    image_name = image_arn.partition("image/")[2]
-    try:
-        # Fetch the image versions
-        next_token = ''
-        while True:
-            response = sagemaker_session.sagemaker_client.list_image_versions(
-                ImageName=image_name,
-                MaxResults=100,
-                SortBy='VERSION',
-                SortOrder='DESCENDING',
-                NextToken=next_token
-            )
-            ecr_uri = resolve_ecr_uri_from_image_versions(sagemaker_session, response['ImageVersions'], image_name)
-            if "NextToken" in response:
-                next_token = response["NextToken"]
-
-            if ecr_uri is not None:
-                return ecr_uri
-
-        # Return error if no versions of the image found
-        error_message = (
-            f"No image version found for image name: {image_name}"
-        )
-        logger.error(error_message)
-        raise Exception(error_message)
-
-    except (ClientError, sagemaker_session.sagemaker_client.exceptions.ResourceNotFound) as e:
-        error_message = e.response["Error"]["Message"]
-        logger.error(error_message)
-        raise Exception(error_message)
 
 def get_pipeline(
     region,
@@ -167,12 +102,12 @@ def get_pipeline(
     default_bucket = sagemaker_session.default_bucket()
     if role is None:
         role = sagemaker.session.get_execution_role(sagemaker_session)
-
-    # parameters for pipeline execution
-    # processing_instance_count = ParameterInteger(name="ProcessingInstanceCount", default_value=1)
-    # processing_instance_type = ParameterString(
-    #     name="ProcessingInstanceType", default_value="ml.m5.xlarge"
-    # )
+    # parameters for pipeline execution
+    sess = boto3.Session()
+    processing_instance_count = ParameterInteger(name="ProcessingInstanceCount", default_value=1)
+    processing_instance_type = ParameterString(
+        name="ProcessingInstanceType", default_value="ml.m5.xlarge"
+    )
     training_instance_type = ParameterString(
         name="TrainingInstanceType", default_value="ml.p2.xlarge"
     )
@@ -184,60 +119,40 @@ def get_pipeline(
     )
     input_data = ParameterString(
         name="InputDataUrl",
-        default_value="s3://{}/DEMO-paddle-byo/".format(default_bucket)
+        default_value="s3://{}/PaddleOCR/input/data".format(default_bucket)
     )
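+    # Identify the custom ECR images for data generation and training; the
+    # "generate-ocr-train-data" and "paddle" repositories are assumed to already
+    # exist in this account and region.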
+    account = sess.client("sts").get_caller_identity()["Account"]
+    region = sess.region_name
+    data_generate_image_name = "generate-ocr-train-data"
+    train_image_name = "paddle"
+    data_generate_image = "{}.dkr.ecr.{}.amazonaws.com/{}".format(account, region, data_generate_image_name)
 
-    training_image_name = "paddle"
-    inference_image_name = "paddle"
-
-    # processing step for feature engineering
-    # try:
-    #     processing_image_uri = sagemaker_session.sagemaker_client.describe_image_version(ImageName=processing_image_name)['ContainerImage']
-    # except (sagemaker_session.sagemaker_client.exceptions.ResourceNotFound):
-    #     processing_image_uri = sagemaker.image_uris.retrieve(
-    #         framework="xgboost",
-    #         region=region,
-    #         version="1.0-1",
-    #         py_version="py3",
-    #         instance_type=processing_instance_type,
-    #     )
-    # script_processor = ScriptProcessor(
-    #     image_uri=processing_image_uri,
-    #     instance_type=processing_instance_type,
-    #     instance_count=processing_instance_count,
-    #     base_job_name=f"{base_job_prefix}/sklearn-abalone-preprocess",
-    #     command=["python3"],
-    #     sagemaker_session=sagemaker_session,
-    #     role=role,
-    # )
-    # step_process = ProcessingStep(
-    #     name="PreprocessAbaloneData",
-    #     processor=script_processor,
-    #     outputs=[
-    #         ProcessingOutput(output_name="train", source="/opt/ml/processing/train"),
-    #         ProcessingOutput(output_name="validation", source="/opt/ml/processing/validation"),
-    #         ProcessingOutput(output_name="test", source="/opt/ml/processing/test"),
-    #     ],
-    #     code=os.path.join(BASE_DIR, "preprocess.py"),
-    #     job_arguments=["--input-data", input_data],
-    # )
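+    # Processing step: run preprocess.py in the data-generation container to produce
+    # the OCR training data under /opt/ml/processing/input/data.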
+    script_processor = ScriptProcessor(
+        image_uri=data_generate_image,
+        instance_type=processing_instance_type,
+        instance_count=processing_instance_count,
+        base_job_name=f"{base_job_prefix}/paddle-ocr-data-generation",
+        command=["python3"],
+        sagemaker_session=sagemaker_session,
+        role=role,
+    )
+    step_process = ProcessingStep(
+        name="GenerateOCRTrainingData",
+        processor=script_processor,
+        outputs=[
+            ProcessingOutput(output_name="data", source="/opt/ml/processing/input/data"),
+        ],
+        code=os.path.join(BASE_DIR, "preprocess.py"),
+        job_arguments=["--input-data", input_data],
+    )
 
     # training step for generating model artifacts
     model_path = f"s3://{sagemaker_session.default_bucket()}/{base_job_prefix}/PaddleOCRTrain"
 
-    # try:
-    #     print(training_image_name)
-    #     training_image_uri = sagemaker_session.sagemaker_client.describe_image_version(ImageName=training_image_name)['ContainerImage']
-    # except (sagemaker_session.sagemaker_client.exceptions.ResourceNotFound):
-    #     training_image_uri = sagemaker.image_uris.retrieve(
-    #         framework="xgboost",
-    #         region=region,
-    #         version="1.0-1",
-    #         py_version="py3",
-    #         instance_type=training_instance_type,
-    #     )
-
-    training_image_uri = "230755935769.dkr.ecr.us-east-1.amazonaws.com/paddle:latest"
+
+    image = "{}.dkr.ecr.{}.amazonaws.com/{}".format(account, region, train_image_name)
+
+    training_image_uri = image
     hyperparameters = {"epoch_num": 10,
                        "print_batch_step": 5,
                        "save_epoch_step": 30,
@@ -252,6 +167,12 @@ def get_pipeline(
         sagemaker_session=sagemaker_session,
         base_job_name=f"{base_job_prefix}/paddleocr-train",
         hyperparameters=hyperparameters,
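+        # Scrape the "best metric" line that PaddleOCR logs during training so the
+        # values surface as SageMaker training-job metrics, e.g.: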
+        # acc: 0.2007992007992008, norm_edit_dis: 0.7116550116550118, fps: 97.10778964378831, best_epoch: 9
+        metric_definitions=[
+            {'Name': 'validation:acc', 'Regex': '.*best metric,.*acc:(.*?),'},
+            {'Name': 'validation:norm_edit_dis', 'Regex': '.*best metric,.*norm_edit_dis:(.*?),'}
+        ]
+
     )
 
 
@@ -260,71 +181,17 @@ def get_pipeline(
         estimator=paddle_train,
         inputs={
             "training": TrainingInput(
-                s3_data=input_data,
-                content_type="text/csv",
-            )
-        },
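+                # Train on the dataset produced by the data-generation processing step.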
+                s3_data=step_process.properties.ProcessingOutputConfig.Outputs[
+                    "data"
+                ].S3Output.S3Uri,
+                content_type="image/jpeg")
+        }
     )
 
-    # processing step for evaluation
-    # script_eval = ScriptProcessor(
-    #     image_uri=training_image_uri,
-    #     command=["python3"],
-    #     instance_type=processing_instance_type,
-    #     instance_count=1,
-    #     base_job_name=f"{base_job_prefix}/script-abalone-eval",
-    #     sagemaker_session=sagemaker_session,
-    #     role=role,
-    # )
-    # evaluation_report = PropertyFile(
-    #     name="AbaloneEvaluationReport",
-    #     output_name="evaluation",
-    #     path="evaluation.json",
-    # )
-    # step_eval = ProcessingStep(
-    #     name="EvaluateAbaloneModel",
-    #     processor=script_eval,
-    #     inputs=[
-    #         ProcessingInput(
-    #             source=step_train.properties.ModelArtifacts.S3ModelArtifacts,
-    #             destination="/opt/ml/processing/model",
-    #         ),
-    #         ProcessingInput(
-    #             source=step_process.properties.ProcessingOutputConfig.Outputs[
-    #                 "test"
-    #             ].S3Output.S3Uri,
-    #             destination="/opt/ml/processing/test",
-    #         ),
-    #     ],
-    #     outputs=[
-    #         ProcessingOutput(output_name="evaluation", source="/opt/ml/processing/evaluation"),
-    #     ],
-    #     code=os.path.join(BASE_DIR, "evaluate.py"),
-    #     property_files=[evaluation_report],
-    # )
-
-    # # register model step that will be conditionally executed
-    # model_metrics = ModelMetrics(
-    #     model_statistics=MetricsSource(
-    #         s3_uri="{}/evaluation.json".format(
-    #             step_eval.arguments["ProcessingOutputConfig"]["Outputs"][0]["S3Output"]["S3Uri"]
-    #         ),
-    #         content_type="application/json"
-    #     )
-    # )
-
-    # try:
-    #     inference_image_uri = sagemaker_session.sagemaker_client.describe_image_version(ImageName=inference_image_name)['ContainerImage']
-    # except (sagemaker_session.sagemaker_client.exceptions.ResourceNotFound):
-    #     inference_image_uri = sagemaker.image_uris.retrieve(
-    #         framework="xgboost",
-    #         region=region,
-    #         version="1.0-1",
-    #         py_version="py3",
-    #         instance_type=inference_instance_type,
-    #     )
-
-    inference_image_uri = "230755935769.dkr.ecr.us-east-1.amazonaws.com/paddle:latest"
+
+
+
+    inference_image_uri = image
     step_register = RegisterModel(
         name="RegisterPaddleOCRModel",
         estimator=paddle_train,
@@ -335,37 +202,34 @@ def get_pipeline(
         inference_instances=["ml.p2.xlarge"],
         transform_instances=["ml.p2.xlarge"],
         model_package_group_name=model_package_group_name,
-        approval_status=model_approval_status,
-        # model_metrics=model_metrics,
+        approval_status=model_approval_status
+    )
+
+    cond_gte = ConditionGreaterThanOrEqualTo(  # You can change the condition here
+        left=step_train.properties.FinalMetricDataList[0].Value,
+        right=0.8,  # You can change the threshold here
+    )
+
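+    # Register the model only when the training job's first reported final metric
+    # (FinalMetricDataList[0]) meets or exceeds the threshold above.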
+    step_cond = ConditionStep(
+        name="PaddleOCRAccuracyCond",
+        conditions=[cond_gte],
+        if_steps=[step_register],
+        else_steps=[],
     )
 
-    # condition step for evaluating model quality and branching execution
-    # cond_lte = ConditionLessThanOrEqualTo(
-    #     left=JsonGet(
-    #         step_name=step_eval.name,
-    #         property_file=evaluation_report,
-    #         json_path="regression_metrics.mse.value"
-    #     ),
-    #     right=6.0,
-    # )
-    # step_cond = ConditionStep(
-    #     name="CheckMSEAbaloneEvaluation",
-    #     conditions=[cond_lte],
-    #     if_steps=[step_register],
-    #     else_steps=[],
-    # )
 
     # pipeline instance
     pipeline = Pipeline(
         name=pipeline_name,
         parameters=[
-            # processing_instance_type,
-            # processing_instance_count,
+            processing_instance_type,
+            processing_instance_count,
             training_instance_type,
             model_approval_status,
             input_data,
         ],
-        steps=[step_train, step_register],
+        steps=[step_process, step_train, step_cond],
+        # steps=[step_train, step_register],
         sagemaker_session=sagemaker_session,
     )
     return pipeline