1515from sagemaker .amazon .amazon_estimator import AmazonAlgorithmEstimatorBase , registry
1616from sagemaker .amazon .common import numpy_to_record_serializer , record_deserializer
1717from sagemaker .amazon .hyperparameter import Hyperparameter as hp # noqa
18- from sagemaker .amazon .validation import isin , gt , lt , ge
18+ from sagemaker .amazon .validation import isin , gt , lt , ge , le
1919from sagemaker .predictor import RealTimePredictor
2020from sagemaker .model import Model
2121from sagemaker .session import Session
@@ -28,28 +28,28 @@ class LinearLearner(AmazonAlgorithmEstimatorBase):
2828 DEFAULT_MINI_BATCH_SIZE = 1000
2929
3030 binary_classifier_model_selection_criteria = hp ('binary_classifier_model_selection_criteria' ,
31- isin ('accuracy' , 'f1' , 'precision_at_target_recall' ,
32- 'recall_at_target_precision' , 'cross_entropy_loss' ) ,
33- data_type = str )
31+ isin ('accuracy' , 'f1' , 'f_beta' , ' precision_at_target_recall' ,
32+ 'recall_at_target_precision' , 'cross_entropy_loss' ,
33+ 'loss_function' ), data_type = str )
3434 target_recall = hp ('target_recall' , (gt (0 ), lt (1 )), "A float in (0,1)" , float )
3535 target_precision = hp ('target_precision' , (gt (0 ), lt (1 )), "A float in (0,1)" , float )
3636 positive_example_weight_mult = hp ('positive_example_weight_mult' , (),
3737 "A float greater than 0 or 'auto' or 'balanced'" , str )
3838 epochs = hp ('epochs' , gt (0 ), "An integer greater-than 0" , int )
39- predictor_type = hp ('predictor_type' , isin ('binary_classifier' , 'regressor' ),
40- 'One of "binary_classifier" or "regressor"' , str )
39+ predictor_type = hp ('predictor_type' , isin ('binary_classifier' , 'regressor' , 'multiclass_classifier' ),
40+ 'One of "binary_classifier" or "multiclass_classifier" or " regressor"' , str )
4141 use_bias = hp ('use_bias' , (), "Either True or False" , bool )
4242 num_models = hp ('num_models' , gt (0 ), "An integer greater-than 0" , int )
4343 num_calibration_samples = hp ('num_calibration_samples' , gt (0 ), "An integer greater-than 0" , int )
4444 init_method = hp ('init_method' , isin ('uniform' , 'normal' ), 'One of "uniform" or "normal"' , str )
4545 init_scale = hp ('init_scale' , gt (0 ), 'A float greater-than 0' , float )
4646 init_sigma = hp ('init_sigma' , gt (0 ), 'A float greater-than 0' , float )
4747 init_bias = hp ('init_bias' , (), 'A number' , float )
48- optimizer = hp ('optimizer' , isin ('sgd' , 'adam' , 'auto' ), 'One of "sgd", "adam" or "auto' , str )
48+ optimizer = hp ('optimizer' , isin ('sgd' , 'adam' , 'rmsprop' , ' auto' ), 'One of "sgd", "adam", "rmsprop " or "auto' , str )
4949 loss = hp ('loss' , isin ('logistic' , 'squared_loss' , 'absolute_loss' , 'hinge_loss' , 'eps_insensitive_squared_loss' ,
50- 'eps_insensitive_absolute_loss' , 'quantile_loss' , 'huber_loss' , 'auto' ),
50+ 'eps_insensitive_absolute_loss' , 'quantile_loss' , 'huber_loss' , 'softmax_loss' , ' auto' ),
5151 '"logistic", "squared_loss", "absolute_loss", "hinge_loss", "eps_insensitive_squared_loss", '
52- '"eps_insensitive_absolute_loss", "quantile_loss", "huber_loss" or "auto"' , str )
52+ '"eps_insensitive_absolute_loss", "quantile_loss", "huber_loss", "softmax_loss" or "auto"' , str )
5353 wd = hp ('wd' , ge (0 ), 'A float greater-than or equal to 0' , float )
5454 l1 = hp ('l1' , ge (0 ), 'A float greater-than or equal to 0' , float )
5555 momentum = hp ('momentum' , (ge (0 ), lt (1 )), 'A float in [0,1)' , float )
@@ -73,6 +73,10 @@ class LinearLearner(AmazonAlgorithmEstimatorBase):
7373 huber_delta = hp ('huber_delta' , ge (0 ), 'A float greater-than or equal to 0' , float )
7474 early_stopping_patience = hp ('early_stopping_patience' , gt (0 ), 'An integer greater-than 0' , int )
7575 early_stopping_tolerance = hp ('early_stopping_tolerance' , gt (0 ), 'A float greater-than 0' , float )
76+ num_classes = hp ('num_classes' , (gt (0 ), le (1000000 )), 'An integer in [1,1000000]' , int )
77+ accuracy_top_k = hp ('accuracy_top_k' , (gt (0 ), le (1000000 )), 'An integer in [1,1000000]' , int )
78+ f_beta = hp ('f_beta' , gt (0 ), 'A float greater-than 0' , float )
79+ balance_multiclass_weights = hp ('balance_multiclass_weights' , (), 'A boolean' , bool )
7680
7781 def __init__ (self , role , train_instance_count , train_instance_type , predictor_type ,
7882 binary_classifier_model_selection_criteria = None , target_recall = None , target_precision = None ,
@@ -83,7 +87,8 @@ def __init__(self, role, train_instance_count, train_instance_type, predictor_ty
8387 lr_scheduler_factor = None , lr_scheduler_minimum_lr = None , normalize_data = None ,
8488 normalize_label = None , unbias_data = None , unbias_label = None , num_point_for_scaler = None , margin = None ,
8589 quantile = None , loss_insensitivity = None , huber_delta = None , early_stopping_patience = None ,
86- early_stopping_tolerance = None , ** kwargs ):
90+ early_stopping_tolerance = None , num_classes = None , accuracy_top_k = None , f_beta = None ,
91+ balance_multiclass_weights = None , ** kwargs ):
8792 """An :class:`Estimator` for binary classification, multiclass classification and regression.
8893
8994 Amazon SageMaker Linear Learner provides a solution for both classification and regression problems, allowing
@@ -119,9 +124,10 @@ def __init__(self, role, train_instance_count, train_instance_type, predictor_ty
119124 the inference code might use the IAM role, if accessing AWS resource.
120125 train_instance_count (int): Number of Amazon EC2 instances to use for training.
121126 train_instance_type (str): Type of EC2 instance to use for training, for example, 'ml.c4.xlarge'.
122- predictor_type (str): The type of predictor to learn. Either "binary_classifier" or "regressor".
123- binary_classifier_model_selection_criteria (str): One of 'accuracy', 'f1', 'precision_at_target_recall',
124- 'recall_at_target_precision', 'cross_entropy_loss'
127+ predictor_type (str): The type of predictor to learn. One of "binary_classifier",
128+ "multiclass_classifier" or "regressor".
129+ binary_classifier_model_selection_criteria (str): One of 'accuracy', 'f1', 'f_beta',
130+ 'precision_at_target_recall', 'recall_at_target_precision', 'cross_entropy_loss', 'loss_function'
125131 target_recall (float): Target recall. Only applicable if binary_classifier_model_selection_criteria is
126132 precision_at_target_recall.
127133 target_precision (float): Target precision. Only applicable if binary_classifier_model_selection_criteria
@@ -139,9 +145,10 @@ def __init__(self, role, train_instance_count, train_instance_type, predictor_ty
139145 init_scale (float): For "uniform" init, the range of values.
140146 init_sigma (float): For "normal" init, the standard-deviation.
141147 init_bias (float): Initial weight for bias term
142- optimizer (str): One of 'sgd', 'adam' or 'auto'
148+ optimizer (str): One of 'sgd', 'adam', 'rmsprop' or 'auto'
143149 loss (str): One of 'logistic', 'squared_loss', 'absolute_loss', 'hinge_loss',
144- 'eps_insensitive_squared_loss', 'eps_insensitive_absolute_loss', 'quantile_loss', 'huber_loss' or 'auto'
150+ 'eps_insensitive_squared_loss', 'eps_insensitive_absolute_loss', 'quantile_loss', 'huber_loss',
151+ 'softmax_loss' or 'auto'.
145152 wd (float): L2 regularization parameter i.e. the weight decay parameter. Use 0 for no L2 regularization.
146153 l1 (float): L1 regularization parameter. Use 0 for no L1 regularization.
147154 momentum (float): Momentum parameter of sgd optimizer.
@@ -180,6 +187,15 @@ def __init__(self, role, train_instance_count, train_instance_type, predictor_ty
180187 early_stopping_tolerance (float): Relative tolerance to measure an improvement in loss. If the ratio of
181188 the improvement in loss divided by the previous best loss is smaller than this value, early stopping will
182189 consider the improvement to be zero.
190+ num_classes (int): The number of classes for the response variable. Required when predictor_type is
191+ multiclass_classifier and ignored otherwise. The classes are assumed to be labeled 0, ..., num_classes - 1.
192+ accuracy_top_k (int): The value of k when computing the Top K Accuracy metric for multiclass
193+ classification. An example is scored as correct if the model assigns one of the top k scores to the true
194+ label.
195+ f_beta (float): The value of beta to use when calculating F score metrics for binary or multiclass
196+ classification. Also used if binary_classifier_model_selection_criteria is f_beta.
197+ balance_multiclass_weights (bool): Whether to use class weights which give each class equal importance in
198+ the loss function. Only used when predictor_type is multiclass_classifier.
183199 **kwargs: base class keyword argument values.
184200 """
185201 super (LinearLearner , self ).__init__ (role , train_instance_count , train_instance_type , ** kwargs )
@@ -221,6 +237,14 @@ def __init__(self, role, train_instance_count, train_instance_type, predictor_ty
221237 self .huber_delta = huber_delta
222238 self .early_stopping_patience = early_stopping_patience
223239 self .early_stopping_tolerance = early_stopping_tolerance
240+ self .num_classes = num_classes
241+ self .accuracy_top_k = accuracy_top_k
242+ self .f_beta = f_beta
243+ self .balance_multiclass_weights = balance_multiclass_weights
244+
245+ if self .predictor_type == 'multiclass_classifier' and (num_classes is None or num_classes < 3 ):
246+ raise ValueError (
247+ "For predictor_type 'multiclass_classifier', 'num_classes' should be set to a value greater than 2." )
224248
225249 def create_model (self ):
226250 """Return a :class:`~sagemaker.amazon.kmeans.LinearLearnerModel` referencing the latest
0 commit comments