1212
1313class TestXGBoostModels (unittest .TestCase ):
1414
15- @unittest .skipIf (sys .version_info [0 ] == 2 , reason = "xgboost converted not tested on python 2" )
15+ @unittest .skipIf (sys .version_info [0 ] == 2 , reason = "xgboost converter not tested on python 2" )
1616 def test_xgb_regressor (self ):
1717 iris = load_iris ()
1818 X = iris .data [:, :2 ]
@@ -24,7 +24,7 @@ def test_xgb_regressor(self):
2424 self .assertTrue (conv_model is not None )
2525 dump_single_regression (xgb , suffix = "-Dec4" )
2626
27- @unittest .skipIf (sys .version_info [0 ] == 2 , reason = "xgboost converted not tested on python 2" )
27+ @unittest .skipIf (sys .version_info [0 ] == 2 , reason = "xgboost converter not tested on python 2" )
2828 def test_xgb_classifier (self ):
2929 iris = load_iris ()
3030 X = iris .data [:, :2 ]
@@ -35,9 +35,9 @@ def test_xgb_classifier(self):
3535 xgb .fit (X , y )
3636 conv_model = convert_xgboost (xgb , initial_types = [('input' , FloatTensorType (shape = [1 , 'None' ]))])
3737 self .assertTrue (conv_model is not None )
38- dump_binary_classification (xgb , verbose = True )
38+ dump_binary_classification (xgb )
3939
40- @unittest .skipIf (sys .version_info [0 ] == 2 , reason = "xgboost converted not tested on python 2" )
40+ @unittest .skipIf (sys .version_info [0 ] == 2 , reason = "xgboost converter not tested on python 2" )
4141 def test_xgb_classifier_multi (self ):
4242 iris = load_iris ()
4343 X = iris .data [:, :2 ]
@@ -49,6 +49,32 @@ def test_xgb_classifier_multi(self):
4949 self .assertTrue (conv_model is not None )
5050 dump_multiple_classification (xgb , allow_failure = "StrictVersion(onnx.__version__) < StrictVersion('1.3.0')" )
5151
52+ @unittest .skipIf (sys .version_info [0 ] == 2 , reason = "xgboost converter not tested on python 2" )
53+ def test_xgb_classifier_multi_reglog (self ):
54+ iris = load_iris ()
55+ X = iris .data [:, :2 ]
56+ y = iris .target
57+
58+ xgb = XGBClassifier (objective = 'reg:logistic' )
59+ xgb .fit (X , y )
60+ conv_model = convert_xgboost (xgb , initial_types = [('input' , FloatTensorType (shape = [1 , 2 ]))])
61+ self .assertTrue (conv_model is not None )
62+ dump_multiple_classification (xgb , suffix = "RegLog" ,
63+ allow_failure = "StrictVersion(onnx.__version__) < StrictVersion('1.3.0')" )
64+
65+ @unittest .skipIf (sys .version_info [0 ] == 2 , reason = "xgboost converter not tested on python 2" )
66+ def test_xgb_classifier_reglog (self ):
67+ iris = load_iris ()
68+ X = iris .data [:, :2 ]
69+ y = iris .target
70+ y [y == 2 ] = 0
71+
72+ xgb = XGBClassifier (objective = 'reg:logistic' )
73+ xgb .fit (X , y )
74+ conv_model = convert_xgboost (xgb , initial_types = [('input' , FloatTensorType (shape = [1 , 2 ]))])
75+ self .assertTrue (conv_model is not None )
76+ dump_binary_classification (xgb , suffix = "RegLog" )
77+
5278
5379if __name__ == "__main__" :
5480 unittest .main ()
0 commit comments