@@ -1222,7 +1222,7 @@ class s3_input(object):
12221222
12231223 def __init__ (self , s3_data , distribution = 'FullyReplicated' , compression = None ,
12241224 content_type = None , record_wrapping = None , s3_data_type = 'S3Prefix' ,
1225- input_mode = None ):
1225+ input_mode = None , attribute_names = None , shuffle_config = None ):
12261226 """Create a definition for input data used by an SageMaker training job.
12271227
12281228 See AWS documentation on the ``CreateTrainingJob`` API for more details on the parameters.
@@ -1234,17 +1234,23 @@ def __init__(self, s3_data, distribution='FullyReplicated', compression=None,
12341234 compression (str): Valid values: 'Gzip', None (default: None). This is used only in Pipe input mode.
12351235 content_type (str): MIME type of the input data (default: None).
12361236 record_wrapping (str): Valid values: 'RecordIO' (default: None).
1237- s3_data_type (str): Valid values: 'S3Prefix', 'ManifestFile'. If 'S3Prefix', ``s3_data`` defines
1238- a prefix of s3 objects to train on. All objects with s3 keys beginning with ``s3_data`` will
1239- be used to train. If 'ManifestFile', then ``s3_data`` defines a single s3 manifest file, listing
1240- each s3 object to train on. The Manifest file format is described in the SageMaker API documentation:
1241- https://docs.aws.amazon.com/sagemaker/latest/dg/API_S3DataSource.html
1237+ s3_data_type (str): Valid values: 'S3Prefix', 'ManifestFile', 'AugmentedManifestFile'. If 'S3Prefix',
1238+ ``s3_data`` defines a prefix of s3 objects to train on. All objects with s3 keys beginning with
1239+ ``s3_data`` will be used to train. If 'ManifestFile' or 'AugmentedManifestFile', then ``s3_data``
1240+ defines a single s3 manifest file or augmented manifest file (respectively), listing the s3 data to
1241+ train on. Both the ManifestFile and AugmentedManifestFile formats are described in the SageMaker API
1242+ documentation: https://docs.aws.amazon.com/sagemaker/latest/dg/API_S3DataSource.html
12421243 input_mode (str): Optional override for this channel's input mode (default: None). By default, channels will
12431244 use the input mode defined on ``sagemaker.estimator.EstimatorBase.input_mode``, but they will ignore
12441245 that setting if this parameter is set.
12451246 * None - Amazon SageMaker will use the input mode specified in the ``Estimator``.
12461247 * 'File' - Amazon SageMaker copies the training dataset from the S3 location to a local directory.
12471248 * 'Pipe' - Amazon SageMaker streams data directly from S3 to the container via a Unix-named pipe.
1249+ attribute_names (list[str]): A list of one or more attribute names to use that are found in a specified
1250+ AugmentedManifestFile.
1251+ shuffle_config (ShuffleConfig): If specified this configuration enables shuffling on this channel. See the
1252+ SageMaker API documentation for more info:
1253+ https://docs.aws.amazon.com/sagemaker/latest/dg/API_ShuffleConfig.html
12481254 """
12491255 self .config = {
12501256 'DataSource' : {
@@ -1264,6 +1270,24 @@ def __init__(self, s3_data, distribution='FullyReplicated', compression=None,
12641270 self .config ['RecordWrapperType' ] = record_wrapping
12651271 if input_mode is not None :
12661272 self .config ['InputMode' ] = input_mode
1273+ if attribute_names is not None :
1274+ self .config ['DataSource' ]['S3DataSource' ]['AttributeNames' ] = attribute_names
1275+ if shuffle_config is not None :
1276+ self .config ['ShuffleConfig' ] = {'Seed' : shuffle_config .seed }
1277+
1278+
class ShuffleConfig(object):
    """Configuration for shuffling a training channel using a seed.

    An instance only stores the seed; ``s3_input`` reads the ``seed``
    attribute to build the channel's ``ShuffleConfig`` entry.

    See the SageMaker API documentation for more detail:
    https://docs.aws.amazon.com/sagemaker/latest/dg/API_ShuffleConfig.html
    """

    def __init__(self, seed):
        """Create a ShuffleConfig.

        Args:
            seed (int): The integer value (``long`` on Python 2) used to seed the shuffled sequence.
        """
        self.seed = seed

    def __repr__(self):
        """Return a readable representation that includes the seed, to aid logging and debugging."""
        # str.format keeps this usable on Python 2 as well as Python 3.
        return '{}(seed={!r})'.format(type(self).__name__, self.seed)
12671291
12681292
12691293class ModelContainer (object ):
0 commit comments