1414from onnxmltools .convert import convert_xgboost
1515from onnxmltools .convert .common .data_types import FloatTensorType
1616from onnxmltools .utils import dump_data_and_model
17+ from onnxruntime import InferenceSession
1718
1819
1920def _fit_classification_model (model , n_classes , is_str = False , dtype = None ):
@@ -31,8 +32,6 @@ def _fit_classification_model(model, n_classes, is_str=False, dtype=None):
3132
3233class TestXGBoostModels (unittest .TestCase ):
3334
34- @unittest .skipIf (sys .version_info [0 ] == 2 ,
35- reason = "xgboost converter not tested on python 2" )
3635 def test_xgb_regressor (self ):
3736 iris = load_diabetes ()
3837 x = iris .data
@@ -42,7 +41,7 @@ def test_xgb_regressor(self):
4241 xgb = XGBRegressor ()
4342 xgb .fit (x_train , y_train )
4443 conv_model = convert_xgboost (
45- xgb , initial_types = [('input' , FloatTensorType (shape = [' None' , ' None' ]))])
44+ xgb , initial_types = [('input' , FloatTensorType (shape = [None , None ]))])
4645 self .assertTrue (conv_model is not None )
4746 dump_data_and_model (
4847 x_test .astype ("float32" ),
@@ -54,12 +53,10 @@ def test_xgb_regressor(self):
5453 "< StrictVersion('1.3.0')" ,
5554 )
5655
57- @unittest .skipIf (sys .version_info [0 ] == 2 ,
58- reason = "xgboost converter not tested on python 2" )
5956 def test_xgb_classifier (self ):
6057 xgb , x_test = _fit_classification_model (XGBClassifier (), 2 )
6158 conv_model = convert_xgboost (
62- xgb , initial_types = [('input' , FloatTensorType (shape = [' None' , ' None' ]))])
59+ xgb , initial_types = [('input' , FloatTensorType (shape = [None , None ]))])
6360 self .assertTrue (conv_model is not None )
6461 dump_data_and_model (
6562 x_test ,
@@ -71,8 +68,6 @@ def test_xgb_classifier(self):
7168 "< StrictVersion('1.3.0')" ,
7269 )
7370
74- @unittest .skipIf (sys .version_info [0 ] == 2 ,
75- reason = "xgboost converter not tested on python 2" )
7671 def test_xgb_classifier_uint8 (self ):
7772 xgb , x_test = _fit_classification_model (
7873 XGBClassifier (), 2 , dtype = np .uint8 )
@@ -89,12 +84,10 @@ def test_xgb_classifier_uint8(self):
8984 "< StrictVersion('1.3.0')" ,
9085 )
9186
92- @unittest .skipIf (sys .version_info [0 ] == 2 ,
93- reason = "xgboost converter not tested on python 2" )
9487 def test_xgb_classifier_multi (self ):
9588 xgb , x_test = _fit_classification_model (XGBClassifier (), 3 )
9689 conv_model = convert_xgboost (
97- xgb , initial_types = [('input' , FloatTensorType (shape = [' None' , ' None' ]))])
90+ xgb , initial_types = [('input' , FloatTensorType (shape = [None , None ]))])
9891 self .assertTrue (conv_model is not None )
9992 dump_data_and_model (
10093 x_test ,
@@ -106,13 +99,11 @@ def test_xgb_classifier_multi(self):
10699 "< StrictVersion('1.3.0')" ,
107100 )
108101
109- @unittest .skipIf (sys .version_info [0 ] == 2 ,
110- reason = "xgboost converter not tested on python 2" )
111102 def test_xgb_classifier_multi_reglog (self ):
112103 xgb , x_test = _fit_classification_model (
113104 XGBClassifier (objective = 'reg:logistic' ), 4 )
114105 conv_model = convert_xgboost (
115- xgb , initial_types = [('input' , FloatTensorType (shape = [' None' , ' None' ]))])
106+ xgb , initial_types = [('input' , FloatTensorType (shape = [None , None ]))])
116107 self .assertTrue (conv_model is not None )
117108 dump_data_and_model (
118109 x_test ,
@@ -124,13 +115,11 @@ def test_xgb_classifier_multi_reglog(self):
124115 "< StrictVersion('1.3.0')" ,
125116 )
126117
127- @unittest .skipIf (sys .version_info [0 ] == 2 ,
128- reason = "xgboost converter not tested on python 2" )
129118 def test_xgb_classifier_reglog (self ):
130119 xgb , x_test = _fit_classification_model (
131120 XGBClassifier (objective = 'reg:logistic' ), 2 )
132121 conv_model = convert_xgboost (
133- xgb , initial_types = [('input' , FloatTensorType (shape = [' None' , ' None' ]))])
122+ xgb , initial_types = [('input' , FloatTensorType (shape = [None , None ]))])
134123 self .assertTrue (conv_model is not None )
135124 dump_data_and_model (
136125 x_test ,
@@ -142,13 +131,11 @@ def test_xgb_classifier_reglog(self):
142131 "< StrictVersion('1.3.0')" ,
143132 )
144133
145- @unittest .skipIf (sys .version_info [0 ] == 2 ,
146- reason = "xgboost converter not tested on python 2" )
147134 def test_xgb_classifier_multi_str_labels (self ):
148135 xgb , x_test = _fit_classification_model (
149136 XGBClassifier (n_estimators = 4 ), 5 , is_str = True )
150137 conv_model = convert_xgboost (
151- xgb , initial_types = [('input' , FloatTensorType (shape = [' None' , ' None' ]))])
138+ xgb , initial_types = [('input' , FloatTensorType (shape = [None , None ]))])
152139 self .assertTrue (conv_model is not None )
153140 dump_data_and_model (
154141 x_test ,
@@ -160,8 +147,6 @@ def test_xgb_classifier_multi_str_labels(self):
160147 "< StrictVersion('1.3.0')" ,
161148 )
162149
163- @unittest .skipIf (sys .version_info [0 ] == 2 ,
164- reason = "xgboost converter not tested on python 2" )
165150 def test_xgb_classifier_multi_discrete_int_labels (self ):
166151 iris = load_iris ()
167152 x = iris .data [:, :2 ]
@@ -176,7 +161,7 @@ def test_xgb_classifier_multi_discrete_int_labels(self):
176161 xgb = XGBClassifier (n_estimators = 3 )
177162 xgb .fit (x_train , y_train )
178163 conv_model = convert_xgboost (
179- xgb , initial_types = [('input' , FloatTensorType (shape = [' None' , ' None' ]))])
164+ xgb , initial_types = [('input' , FloatTensorType (shape = [None , None ]))])
180165 self .assertTrue (conv_model is not None )
181166 dump_data_and_model (
182167 x_test .astype ("float32" ),
@@ -188,8 +173,6 @@ def test_xgb_classifier_multi_discrete_int_labels(self):
188173 "< StrictVersion('1.3.0')" ,
189174 )
190175
191- @unittest .skipIf (sys .version_info [0 ] == 2 ,
192- reason = "xgboost converter not tested on python 2" )
193176 def test_xgboost_booster_classifier_bin (self ):
194177 x , y = make_classification (n_classes = 2 , n_features = 5 ,
195178 n_samples = 100 ,
@@ -207,8 +190,6 @@ def test_xgboost_booster_classifier_bin(self):
207190 allow_failure = "StrictVersion(onnx.__version__) < StrictVersion('1.3.0')" ,
208191 basename = "XGBBoosterMCl" )
209192
210- @unittest .skipIf (sys .version_info [0 ] == 2 ,
211- reason = "xgboost converter not tested on python 2" )
212193 def test_xgboost_booster_classifier_multiclass (self ):
213194 x , y = make_classification (n_classes = 3 , n_features = 5 ,
214195 n_samples = 100 ,
@@ -227,8 +208,6 @@ def test_xgboost_booster_classifier_multiclass(self):
227208 allow_failure = "StrictVersion(onnx.__version__) < StrictVersion('1.3.0')" ,
228209 basename = "XGBBoosterMCl" )
229210
230- @unittest .skipIf (sys .version_info [0 ] == 2 ,
231- reason = "xgboost converter not tested on python 2" )
232211 def test_xgboost_booster_classifier_reg (self ):
233212 x , y = make_classification (n_classes = 2 , n_features = 5 ,
234213 n_samples = 100 ,
@@ -247,8 +226,6 @@ def test_xgboost_booster_classifier_reg(self):
247226 allow_failure = "StrictVersion(onnx.__version__) < StrictVersion('1.3.0')" ,
248227 basename = "XGBBoosterReg" )
249228
250- @unittest .skipIf (sys .version_info [0 ] == 2 ,
251- reason = "xgboost converter not tested on python 2" )
252229 def test_xgboost_10 (self ):
253230 this = os .path .abspath (os .path .dirname (__file__ ))
254231 train = os .path .join (this , "input_fail_train.csv" )
@@ -282,6 +259,29 @@ def test_xgboost_10(self):
282259 allow_failure = "StrictVersion(onnx.__version__) < StrictVersion('1.3.0')" ,
283260 basename = "XGBBoosterRegBug" )
284261
262+ def test_xgboost_classifier_i5450 (self ):
263+ iris = load_iris ()
264+ X , y = iris .data , iris .target
265+ X_train , X_test , y_train , y_test = train_test_split (X , y , random_state = 10 )
266+ clr = XGBClassifier (objective = "multi:softmax" , max_depth = 1 , n_estimators = 2 )
267+ clr .fit (X_train , y_train , eval_set = [(X_test , y_test )], early_stopping_rounds = 40 )
268+ initial_type = [('float_input' , FloatTensorType ([None , 4 ]))]
269+ onx = convert_xgboost (clr , initial_types = initial_type )
270+ sess = InferenceSession (onx .SerializeToString ())
271+ input_name = sess .get_inputs ()[0 ].name
272+ label_name = sess .get_outputs ()[1 ].name
273+ predict_list = [1. , 20. , 466. , 0. ]
274+ predict_array = np .array (predict_list ).reshape ((1 ,- 1 )).astype (np .float32 )
275+ pred_onx = sess .run ([label_name ], {input_name : predict_array })[0 ]
276+ pred_xgboost = sessresults = clr .predict_proba (predict_array )
277+ bst = clr .get_booster ()
278+ bst .dump_model ('dump.raw.txt' )
279+ dump_data_and_model (
280+ X_test .astype (np .float32 ) + 1e-5 ,
281+ clr , onx ,
282+ allow_failure = "StrictVersion(onnx.__version__) < StrictVersion('1.3.0')" ,
283+ basename = "XGBClassifierIris" )
284+
285285 def test_xgboost_example_mnist (self ):
286286 """
287287 Train a simple xgboost model and store associated artefacts.
0 commit comments