@@ -671,6 +671,9 @@ def transform_config(
671671 compression_type = None ,
672672 split_type = None ,
673673 job_name = None ,
674+ input_filter = None ,
675+ output_filter = None ,
676+ join_source = None ,
674677):
675678 """Export Airflow transform config from a SageMaker transformer
676679
@@ -686,13 +689,38 @@ def transform_config(
686689
687690 * 'ManifestFile' - the S3 URI points to a single manifest file listing each S3 object
688691 to use as an input for the transform job.
692+
689693 content_type (str): MIME type of the input data (default: None).
690694 compression_type (str): Compression type of the input data, if
691695 compressed (default: None). Valid values: 'Gzip', None.
692696 split_type (str): The record delimiter for the input object (default:
693697 'None'). Valid values: 'None', 'Line', 'RecordIO', and 'TFRecord'.
694698 job_name (str): job name (default: None). If not specified, one will be
695699 generated.
700+ input_filter (str): A JSONPath to select a portion of the input to
701+ pass to the algorithm container for inference. If you omit the
702+ field, it gets the value '$', representing the entire input.
703+ For CSV data, each row is taken as a JSON array,
704+ so only index-based JSONPaths can be applied, e.g. $[0], $[1:].
705+ CSV data should follow the `RFC format <https://tools.ietf.org/html/rfc4180>`_.
706+ See `Supported JSONPath Operators
707+ <https://docs.aws.amazon.com/sagemaker/latest/dg/batch-transform-data-processing.html#data-processing-operators>`_
708+ for a table of supported JSONPath operators.
709+ For more information, see the SageMaker API documentation for
710+ `CreateTransformJob
711+ <https://docs.aws.amazon.com/sagemaker/latest/dg/API_CreateTransformJob.html>`_.
712+ Some examples: "$[1:]", "$.features" (default: None).
713+ output_filter (str): A JSONPath to select a portion of the
714+ joined/original output to return as the output.
715+ For more information, see the SageMaker API documentation for
716+ `CreateTransformJob
717+ <https://docs.aws.amazon.com/sagemaker/latest/dg/API_CreateTransformJob.html>`_.
718+ Some examples: "$[1:]", "$.prediction" (default: None).
719+ join_source (str): The source of data to be joined to the transform
720+ output. It can be set to 'Input' meaning the entire input record
721+ will be joined to the inference result. You can use OutputFilter
722+ to select the useful portion before uploading to S3. (default:
723+ None). Valid values: 'Input', None.
696724
697725 Returns:
698726 dict: Transform config that can be directly used by
@@ -723,6 +751,12 @@ def transform_config(
723751 "TransformResources" : job_config ["resource_config" ],
724752 }
725753
754+ data_processing = sagemaker .transformer ._TransformJob ._prepare_data_processing (
755+ input_filter , output_filter , join_source
756+ )
757+ if data_processing is not None :
758+ config ["DataProcessing" ] = data_processing
759+
726760 if transformer .strategy is not None :
727761 config ["BatchStrategy" ] = transformer .strategy
728762
@@ -768,6 +802,9 @@ def transform_config_from_estimator(
768802 model_server_workers = None ,
769803 image = None ,
770804 vpc_config_override = None ,
805+ input_filter = None ,
806+ output_filter = None ,
807+ join_source = None ,
771808):
772809 """Export Airflow transform config from a SageMaker estimator
773810
@@ -836,9 +873,35 @@ def transform_config_from_estimator(
836873 image (str): A container image to use for deploying the model
837874 vpc_config_override (dict[str, list[str]]): Override for VpcConfig set on
838875 the model. Default: use subnets and security groups from this Estimator.
876+
839877 * 'Subnets' (list[str]): List of subnet ids.
840878 * 'SecurityGroupIds' (list[str]): List of security group ids.
841879
880+ input_filter (str): A JSONPath to select a portion of the input to
881+ pass to the algorithm container for inference. If you omit the
882+ field, it gets the value '$', representing the entire input.
883+ For CSV data, each row is taken as a JSON array,
884+ so only index-based JSONPaths can be applied, e.g. $[0], $[1:].
885+ CSV data should follow the `RFC format <https://tools.ietf.org/html/rfc4180>`_.
886+ See `Supported JSONPath Operators
887+ <https://docs.aws.amazon.com/sagemaker/latest/dg/batch-transform-data-processing.html#data-processing-operators>`_
888+ for a table of supported JSONPath operators.
889+ For more information, see the SageMaker API documentation for
890+ `CreateTransformJob
891+ <https://docs.aws.amazon.com/sagemaker/latest/dg/API_CreateTransformJob.html>`_.
892+ Some examples: "$[1:]", "$.features" (default: None).
893+ output_filter (str): A JSONPath to select a portion of the
894+ joined/original output to return as the output.
895+ For more information, see the SageMaker API documentation for
896+ `CreateTransformJob
897+ <https://docs.aws.amazon.com/sagemaker/latest/dg/API_CreateTransformJob.html>`_.
898+ Some examples: "$[1:]", "$.prediction" (default: None).
899+ join_source (str): The source of data to be joined to the transform
900+ output. It can be set to 'Input' meaning the entire input record
901+ will be joined to the inference result. You can use OutputFilter
902+ to select the useful portion before uploading to S3. (default:
903+ None). Valid values: 'Input', None.
904+
842905 Returns:
843906 dict: Transform config that can be directly used by
844907 SageMakerTransformOperator in Airflow.
@@ -891,7 +954,16 @@ def transform_config_from_estimator(
891954 transformer .model_name = model_base_config ["ModelName" ]
892955
893956 transform_base_config = transform_config (
894- transformer , data , data_type , content_type , compression_type , split_type , job_name
957+ transformer ,
958+ data ,
959+ data_type ,
960+ content_type ,
961+ compression_type ,
962+ split_type ,
963+ job_name ,
964+ input_filter ,
965+ output_filter ,
966+ join_source ,
895967 )
896968
897969 config = {"Model" : model_base_config , "Transform" : transform_base_config }
0 commit comments