1313from sagemaker .amazon .amazon_estimator import AmazonAlgorithmEstimatorBase , registry
1414from sagemaker .amazon .common import numpy_to_record_serializer , record_deserializer
1515from sagemaker .amazon .hyperparameter import Hyperparameter as hp # noqa
16- from sagemaker .amazon .validation import isin , gt , lt , isint , isbool , isnumber
16+ from sagemaker .amazon .validation import isin , gt , lt
1717from sagemaker .predictor import RealTimePredictor
1818from sagemaker .model import Model
1919from sagemaker .session import Session
@@ -27,40 +27,41 @@ class LinearLearner(AmazonAlgorithmEstimatorBase):
2727
2828 binary_classifier_model_selection_criteria = hp ('binary_classifier_model_selection_criteria' ,
2929 isin ('accuracy' , 'f1' , 'precision_at_target_recall' ,
30- 'recall_at_target_precision' , 'cross_entropy_loss' ))
31- target_recall = hp ('target_recall' , (gt (0 ), lt (1 )), "A float in (0,1)" )
32- target_precision = hp ('target_precision' , (gt (0 ), lt (1 )), "A float in (0,1)" )
33- positive_example_weight_mult = hp ('positive_example_weight_mult' , gt (0 ), "A float greater than 0" )
34- epochs = hp ('epochs' , (gt (0 ), isint ), "An integer greater-than 0" )
30+ 'recall_at_target_precision' , 'cross_entropy_loss' ),
31+ data_type = str )
32+ target_recall = hp ('target_recall' , (gt (0 ), lt (1 )), "A float in (0,1)" , float )
33+ target_precision = hp ('target_precision' , (gt (0 ), lt (1 )), "A float in (0,1)" , float )
34+ positive_example_weight_mult = hp ('positive_example_weight_mult' , gt (0 ), "A float greater than 0" , float )
35+ epochs = hp ('epochs' , gt (0 ), "An integer greater-than 0" , int )
3536 predictor_type = hp ('predictor_type' , isin ('binary_classifier' , 'regressor' ),
36- 'One of "binary_classifier" or "regressor"' )
37- use_bias = hp ('use_bias' , isbool , "Either True or False" )
38- num_models = hp ('num_models' , ( gt (0 ), isint ), "An integer greater-than 0" )
39- num_calibration_samples = hp ('num_calibration_samples' , ( gt (0 ), isint ), "An integer greater-than 0" )
40- init_method = hp ('init_method' , isin ('uniform' , 'normal' ), 'One of "uniform" or "normal"' )
41- init_scale = hp ('init_scale' , (gt (- 1 ), lt (1 )), 'A float in (-1, 1)' )
42- init_sigma = hp ('init_sigma' , (gt (0 ), lt (1 )), 'A float in (0, 1)' )
43- init_bias = hp ('init_bias' , isnumber , 'A number' )
44- optimizer = hp ('optimizer' , isin ('sgd' , 'adam' , 'auto' ), 'One of "sgd", "adam" or "auto' )
37+ 'One of "binary_classifier" or "regressor"' , str )
38+ use_bias = hp ('use_bias' , () , "Either True or False" , bool )
39+ num_models = hp ('num_models' , gt (0 ), "An integer greater-than 0" , int )
40+ num_calibration_samples = hp ('num_calibration_samples' , gt (0 ), "An integer greater-than 0" , int )
41+ init_method = hp ('init_method' , isin ('uniform' , 'normal' ), 'One of "uniform" or "normal"' , str )
42+ init_scale = hp ('init_scale' , (gt (- 1 ), lt (1 )), 'A float in (-1, 1)' , float )
43+ init_sigma = hp ('init_sigma' , (gt (0 ), lt (1 )), 'A float in (0, 1)' , float )
44+ init_bias = hp ('init_bias' , () , 'A number' , float )
45+ optimizer = hp ('optimizer' , isin ('sgd' , 'adam' , 'auto' ), 'One of "sgd", "adam" or "auto' , str )
4546 loss = hp ('loss' , isin ('logistic' , 'squared_loss' , 'absolute_loss' , 'auto' ),
46- '"logistic", "squared_loss", "absolute_loss" or"auto"' )
47- wd = hp ('wd' , (gt (0 ), lt (1 )), 'A float in (0,1)' )
48- l1 = hp ('l1' , (gt (0 ), lt (1 )), 'A float in (0,1)' )
49- momentum = hp ('momentum' , (gt (0 ), lt (1 )), 'A float in (0,1)' )
50- learning_rate = hp ('learning_rate' , (gt (0 ), lt (1 )), 'A float in (0,1)' )
51- beta_1 = hp ('beta_1' , (gt (0 ), lt (1 )), 'A float in (0,1)' )
52- beta_2 = hp ('beta_1' , (gt (0 ), lt (1 )), 'A float in (0,1)' )
53- bias_lr_mult = hp ('bias_lr_mult' , gt (0 ), 'A float greater-than 0' )
54- bias_wd_mult = hp ('bias_wd_mult' , gt (0 ), 'A float greater-than 0' )
55- use_lr_scheduler = hp ('use_lr_scheduler' , isbool , 'A boolean' )
56- lr_scheduler_step = hp ('lr_scheduler_step' , ( gt (0 ), isint ), 'An integer greater-than 0' )
57- lr_scheduler_factor = hp ('lr_scheduler_factor' , (gt (0 ), lt (1 )), 'A float in (0,1)' )
58- lr_scheduler_minimum_lr = hp ('lr_scheduler_minimum_lr' , gt (0 ), 'A float greater-than 0' )
59- normalize_data = hp ('normalize_data' , isbool , 'A boolean' )
60- normalize_label = hp ('normalize_label' , isbool , 'A boolean' )
61- unbias_data = hp ('unbias_data' , isbool , 'A boolean' )
62- unbias_label = hp ('unbias_label' , isbool , 'A boolean' )
63- num_point_for_scalar = hp ('num_point_for_scalar' , ( isint , gt (0 )) , 'An integer greater-than 0' )
47+ '"logistic", "squared_loss", "absolute_loss" or"auto"' , str )
48+ wd = hp ('wd' , (gt (0 ), lt (1 )), 'A float in (0,1)' , float )
49+ l1 = hp ('l1' , (gt (0 ), lt (1 )), 'A float in (0,1)' , float )
50+ momentum = hp ('momentum' , (gt (0 ), lt (1 )), 'A float in (0,1)' , float )
51+ learning_rate = hp ('learning_rate' , (gt (0 ), lt (1 )), 'A float in (0,1)' , float )
52+ beta_1 = hp ('beta_1' , (gt (0 ), lt (1 )), 'A float in (0,1)' , float )
53+ beta_2 = hp ('beta_1' , (gt (0 ), lt (1 )), 'A float in (0,1)' , float )
54+ bias_lr_mult = hp ('bias_lr_mult' , gt (0 ), 'A float greater-than 0' , float )
55+ bias_wd_mult = hp ('bias_wd_mult' , gt (0 ), 'A float greater-than 0' , float )
56+ use_lr_scheduler = hp ('use_lr_scheduler' , () , 'A boolean' , bool )
57+ lr_scheduler_step = hp ('lr_scheduler_step' , gt (0 ), 'An integer greater-than 0' , int )
58+ lr_scheduler_factor = hp ('lr_scheduler_factor' , (gt (0 ), lt (1 )), 'A float in (0,1)' , float )
59+ lr_scheduler_minimum_lr = hp ('lr_scheduler_minimum_lr' , gt (0 ), 'A float greater-than 0' , float )
60+ normalize_data = hp ('normalize_data' , () , 'A boolean' , bool )
61+ normalize_label = hp ('normalize_label' , () , 'A boolean' , bool )
62+ unbias_data = hp ('unbias_data' , () , 'A boolean' , bool )
63+ unbias_label = hp ('unbias_label' , () , 'A boolean' , bool )
64+ num_point_for_scalar = hp ('num_point_for_scalar' , gt (0 ), 'An integer greater-than 0' , int )
6465
6566 def __init__ (self , role , train_instance_count , train_instance_type , predictor_type = 'binary_classifier' ,
6667 binary_classifier_model_selection_criteria = None , target_recall = None , target_precision = None ,
0 commit comments