@@ -79,7 +79,7 @@ def __init__(self, model_name, instance_count, instance_type, strategy=None, ass
7979 self .sagemaker_session = sagemaker_session or Session ()
8080
8181 def transform (self , data , data_type = 'S3Prefix' , content_type = None , compression_type = None , split_type = None ,
82- job_name = None ):
82+ job_name = None , input_filter = None , output_filter = None , join_source = None ):
8383 """Start a new transform job.
8484
8585 Args:
@@ -97,6 +97,15 @@ def transform(self, data, data_type='S3Prefix', content_type=None, compression_t
9797 split_type (str): The record delimiter for the input object (default: 'None').
9898 Valid values: 'None', 'Line', 'RecordIO', and 'TFRecord'.
9999 job_name (str): job name (default: None). If not specified, one will be generated.
100+ input_filter (str): A JSONPath to select a portion of the input to pass to the algorithm container for
101+ inference. If you omit the field, it gets the value '$', representing the entire input.
102+ Some examples: "$[1:]", "$.features"(default: None).
103+ output_filter (str): A JSONPath to select a portion of the joined/original output to return as the output.
104+ Some examples: "$[1:]", "$.prediction" (default: None).
105+ join_source (str): The source of data to be joined to the transform output. It can be set to 'Input'
106+ meaning the entire input record will be joined to the inference result.
107+ You can use OutputFilter to select the useful portion before uploading to S3. (default: None).
108+ Valid values: Input, None.
100109 """
101110 local_mode = self .sagemaker_session .local_mode
102111 if not local_mode and not data .startswith ('s3://' ):
@@ -116,7 +125,7 @@ def transform(self, data, data_type='S3Prefix', content_type=None, compression_t
116125 self .output_path = 's3://{}/{}' .format (self .sagemaker_session .default_bucket (), self ._current_job_name )
117126
118127 self .latest_transform_job = _TransformJob .start_new (self , data , data_type , content_type , compression_type ,
119- split_type )
128+ split_type , input_filter , output_filter , join_source )
120129
121130 def delete_model (self ):
122131 """Delete the corresponding SageMaker model for this Transformer.
@@ -214,16 +223,19 @@ def _prepare_init_params_from_job_description(cls, job_details):
214223
215224class _TransformJob (_Job ):
216225 @classmethod
217- def start_new (cls , transformer , data , data_type , content_type , compression_type , split_type ):
226+ def start_new (cls , transformer , data , data_type , content_type , compression_type ,
227+ split_type , input_filter , output_filter , join_source ):
218228 config = _TransformJob ._load_config (data , data_type , content_type , compression_type , split_type , transformer )
229+ data_processing = _TransformJob ._prepare_data_processing (input_filter , output_filter , join_source )
219230
220231 transformer .sagemaker_session .transform (job_name = transformer ._current_job_name ,
221232 model_name = transformer .model_name , strategy = transformer .strategy ,
222233 max_concurrent_transforms = transformer .max_concurrent_transforms ,
223234 max_payload = transformer .max_payload , env = transformer .env ,
224235 input_config = config ['input_config' ],
225236 output_config = config ['output_config' ],
226- resource_config = config ['resource_config' ], tags = transformer .tags )
237+ resource_config = config ['resource_config' ],
238+ tags = transformer .tags , data_processing = data_processing )
227239
228240 return cls (transformer .sagemaker_session , transformer ._current_job_name )
229241
@@ -287,3 +299,21 @@ def _prepare_resource_config(instance_count, instance_type, volume_kms_key):
287299 config ['VolumeKmsKeyId' ] = volume_kms_key
288300
289301 return config
302+
303+ @staticmethod
304+ def _prepare_data_processing (input_filter , output_filter , join_source ):
305+ config = {}
306+
307+ if input_filter is not None :
308+ config ['InputFilter' ] = input_filter
309+
310+ if output_filter is not None :
311+ config ['OutputFilter' ] = output_filter
312+
313+ if join_source is not None :
314+ config ['JoinSource' ] = join_source
315+
316+ if len (config ) == 0 :
317+ return None
318+
319+ return config
0 commit comments