@@ -21,28 +21,49 @@ class WrappedBooster:
2121
2222 def __init__ (self , booster ):
2323 self .booster_ = booster
24- _model_dict = self .booster_ .dump_model ()
25- self .classes_ = self ._generate_classes (_model_dict )
26- self .n_features_ = len (_model_dict ['feature_names' ])
27- if (_model_dict ['objective' ].startswith ('binary' ) or
28- _model_dict ['objective' ].startswith ('multiclass' )):
24+ self .n_features_ = self .booster_ .feature_name ()
25+ self .objective_ = self .get_objective ()
26+ if self .objective_ .startswith ('binary' ):
2927 self .operator_name = 'LgbmClassifier'
30- elif _model_dict ['objective' ].startswith (('regression' , 'poisson' , 'gamma' )):
28+ self .classes_ = self ._generate_classes (booster )
29+ elif self .objective_ .startswith ('multiclass' ):
30+ self .operator_name = 'LgbmClassifier'
31+ self .classes_ = self ._generate_classes (booster )
32+ elif self .objective_ .startswith ('regression' ):
3133 self .operator_name = 'LgbmRegressor'
3234 else :
33- # Other objectives are not supported.
34- raise ValueError ("Unsupported LightGbm objective: '{}'." .format (_model_dict ['objective' ]))
35- if _model_dict .get ('average_output' , False ):
35+ raise NotImplementedError (
36+ 'Unsupported LightGbm objective: %r.' % self .objective_ )
37+ average_output = self .booster_ .attr ('average_output' )
38+ if average_output :
3639 self .boosting_type = 'rf'
3740 else :
3841 # Other than random forest, other boosting types do not affect later conversion.
3942 # Here `gbdt` is chosen for no reason.
4043 self .boosting_type = 'gbdt'
4144
42- def _generate_classes (self , model_dict ):
43- if model_dict ['num_class' ] == 1 :
45+ @staticmethod
46+ def _generate_classes (booster ):
47+ if isinstance (booster , dict ):
48+ num_class = booster ['num_class' ]
49+ else :
50+ num_class = booster .attr ('num_class' )
51+ if num_class is None :
52+ dp = booster .dump_model (num_iteration = 1 )
53+ num_class = dp ['num_class' ]
54+ if num_class == 1 :
4455 return numpy .asarray ([0 , 1 ])
45- return numpy .arange (model_dict ['num_class' ])
56+ return numpy .arange (num_class )
57+
58+ def get_objective (self ):
59+ "Returns the objective."
60+ if hasattr (self , 'objective_' ) and self .objective_ is not None :
61+ return self .objective_
62+ objective = self .booster_ .attr ('objective' )
63+ if objective is not None :
64+ return objective
65+ dp = self .booster_ .dump_model (num_iteration = 1 )
66+ return dp ['objective' ]
4667
4768
4869def _get_lightgbm_operator_name (model ):
0 commit comments