1919from sagemaker .amazon import validation
2020from sagemaker .amazon .hyperparameter import Hyperparameter as hp # noqa
2121from sagemaker .amazon .common import write_numpy_to_dense_tensor
22- from sagemaker .estimator import EstimatorBase
22+ from sagemaker .estimator import EstimatorBase , _TrainingJob
2323from sagemaker .session import s3_input
2424from sagemaker .utils import sagemaker_timestamp
2525
@@ -92,11 +92,38 @@ def _prepare_init_params_from_job_description(cls, job_details):
9292 del init_params ['image' ]
9393 return init_params
9494
95- def fit (self , records , mini_batch_size = None , ** kwargs ):
def _prepare_for_training(self, records, mini_batch_size=None, job_name=None):
    """Set hyperparameters needed for training.

    Resolves ``feature_dim`` from the provided record set(s) and stores it,
    together with ``mini_batch_size``, as instance attributes consumed when
    the training job is created.

    Args:
        records (:class:`~RecordSet` or list): The records to train this ``Estimator`` on.
            If a list is given, it must contain a record set whose ``channel`` is ``'train'``.
        mini_batch_size (int or None): The size of each mini-batch to use when training.
            If ``None``, a default value will be used.
        job_name (str): Name of the training job to be created. If not specified, one is
            generated, using the base name given to the constructor if applicable.

    Raises:
        ValueError: If ``records`` is a list with no ``'train'`` channel.
    """
    # Let the base estimator set up the job name (and any base preparation).
    super(AmazonAlgorithmEstimatorBase, self)._prepare_for_training(job_name=job_name)

    feature_dim = None

    if isinstance(records, list):
        # The 'train' channel drives the feature_dim hyperparameter for the whole job.
        for record in records:
            if record.channel == 'train':
                feature_dim = record.feature_dim
                break
        if feature_dim is None:
            raise ValueError('Must provide train channel.')
    else:
        feature_dim = records.feature_dim

    self.feature_dim = feature_dim
    self.mini_batch_size = mini_batch_size
122+ def fit (self , records , mini_batch_size = None , wait = True , logs = True , job_name = None ):
96123 """Fit this Estimator on serialized Record objects, stored in S3.
97124
98125 ``records`` should be an instance of :class:`~RecordSet`. This defines a collection of
99- s3 data files to train this ``Estimator`` on.
126+ S3 data files to train this ``Estimator`` on.
100127
101128 Training data is expected to be encoded as dense or sparse vectors in the "values" feature
102129 on each Record. If the data is labeled, the label is expected to be encoded as a list of
@@ -110,15 +137,19 @@ def fit(self, records, mini_batch_size=None, **kwargs):
110137
111138 Args:
112139 records (:class:`~RecordSet`): The records to train this ``Estimator`` on
113- mini_batch_size (int or None): The size of each mini-batch to use when training. If None, a
140+ mini_batch_size (int or None): The size of each mini-batch to use when training. If `` None`` , a
114141 default value will be used.
142+ wait (bool): Whether the call should wait until the job completes (default: True).
143+ logs (bool): Whether to show the logs produced by the job.
144+ Only meaningful when wait is True (default: True).
145+ job_name (str): Training job name. If not specified, the estimator generates a default job name,
146+ based on the training image name and current timestamp.
115147 """
116- self .feature_dim = records .feature_dim
117- self .mini_batch_size = mini_batch_size
148+ self ._prepare_for_training (records , job_name = job_name , mini_batch_size = mini_batch_size )
118149
119- data = { records . channel : s3_input ( records . s3_data , distribution = 'ShardedByS3Key' ,
120- s3_data_type = records . s3_data_type )}
121- super ( AmazonAlgorithmEstimatorBase , self ). fit ( data , ** kwargs )
150+ self . latest_training_job = _TrainingJob . start_new ( self , records )
151+ if wait :
152+ self . latest_training_job . wait ( logs = logs )
122153
123154 def record_set (self , train , labels = None , channel = "train" ):
124155 """Build a :class:`~RecordSet` from a numpy :class:`~ndarray` matrix and label vector.
@@ -180,6 +211,14 @@ def __repr__(self):
180211 """Return an unambiguous representation of this RecordSet"""
181212 return str ((RecordSet , self .__dict__ ))
182213
def data_channel(self):
    """Return a dictionary representing this record set's training data channel.

    Maps the channel name to an ``s3_input`` built by :meth:`records_s3_input`,
    in the shape expected by ``fit()``.

    Returns:
        dict[str, s3_input]: ``{self.channel: self.records_s3_input()}``.
    """
    return {self.channel: self.records_s3_input()}
217+
def records_s3_input(self):
    """Return an ``s3_input`` describing this record set's training data.

    The input is sharded by S3 key so that distributed training spreads the
    record files across hosts.

    Returns:
        s3_input: Input channel configuration for ``self.s3_data``.
    """
    return s3_input(self.s3_data, distribution='ShardedByS3Key', s3_data_type=self.s3_data_type)
221+
183222
184223def _build_shards (num_shards , array ):
185224 if num_shards < 1 :
0 commit comments