@@ -269,7 +269,14 @@ def fit(
269269 if wait :
270270 self .latest_training_job .wait (logs = logs )
271271
272- def record_set (self , train , labels = None , channel = "train" , encrypt = False ):
272+ def record_set (
273+ self ,
274+ train ,
275+ labels = None ,
276+ channel = "train" ,
277+ encrypt = False ,
278+ distribution = "ShardedByS3Key" ,
279+ ):
273280 """Build a :class:`~RecordSet` from a numpy :class:`~ndarray` matrix and label vector.
274281
275282 For the 2D ``ndarray`` ``train``, each row is converted to a
@@ -294,6 +301,8 @@ def record_set(self, train, labels=None, channel="train", encrypt=False):
294301 should be assigned to.
295302 encrypt (bool): Specifies whether the objects uploaded to S3 are
296303 encrypted on the server side using AES-256 (default: ``False``).
304+ distribution (str): The SageMaker TrainingJob channel S3 data
305+ distribution type (default: ``ShardedByS3Key``).
297306
298307 Returns:
299308 RecordSet: A RecordSet referencing the encoded, uploaded training
@@ -316,6 +325,7 @@ def record_set(self, train, labels=None, channel="train", encrypt=False):
316325 num_records = train .shape [0 ],
317326 feature_dim = train .shape [1 ],
318327 channel = channel ,
328+ distribution = distribution ,
319329 )
320330
321331 def _get_default_mini_batch_size (self , num_records : int ):
@@ -343,6 +353,7 @@ def __init__(
343353 feature_dim : int ,
344354 s3_data_type : Union [str , PipelineVariable ] = "ManifestFile" ,
345355 channel : Union [str , PipelineVariable ] = "train" ,
356+ distribution : str = "ShardedByS3Key" ,
346357 ):
347358 """A collection of Amazon :class:`~Record` objects serialized and stored in S3.
348359
@@ -358,12 +369,15 @@ def __init__(
358369 single s3 manifest file, listing each s3 object to train on.
359370 channel (str or PipelineVariable): The SageMaker Training Job channel this RecordSet
360371 should be bound to
372+ distribution (str): The SageMaker TrainingJob S3 data distribution type.
373+ Valid values: 'ShardedByS3Key', 'FullyReplicated'.
361374 """
362375 self .s3_data = s3_data
363376 self .feature_dim = feature_dim
364377 self .num_records = num_records
365378 self .s3_data_type = s3_data_type
366379 self .channel = channel
380+ self .distribution = distribution
367381
368382 def __repr__ (self ):
369383 """Return an unambiguous representation of this RecordSet"""
@@ -377,7 +391,7 @@ def data_channel(self):
377391 def records_s3_input (self ):
378392 """Return a TrainingInput to represent the training data"""
379393 return TrainingInput (
380- self .s3_data , distribution = "ShardedByS3Key" , s3_data_type = self .s3_data_type
394+ self .s3_data , distribution = self . distribution , s3_data_type = self .s3_data_type
381395 )
382396
383397
0 commit comments