
Commit d2c8fb7

hongzmsft authored and xadupre committed
feat: add support for lightgbm.Booster (#329)
Users who need to train a LightGBM model on datasets too large to fit in memory often train through the Booster API, which supports two_round loading of huge datasets, or through the command line, which exports a model.txt file that can only be loaded back as a Booster. Currently the ONNX converter supports neither case. Most of the converter's work consists of processing information from the dumped model dictionary, and the remaining pieces can be inferred from it as well. This change wraps the Booster to expose that information and facilitate the conversion process. Multiclass models are not supported yet: the exported ONNX model has an issue with its ZipMap node.
1 parent 6ac6a11 commit d2c8fb7
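
As a quick illustration of the new capability (not part of the commit; the data, input name, and shape below are made up for this sketch), a Booster trained directly through lightgbm.train could now be converted roughly like this:

import lightgbm
import numpy
from onnxmltools.convert import convert_lightgbm
from onnxmltools.convert.common.data_types import FloatTensorType

# Train a Booster without the scikit-learn wrapper (illustrative data only).
X = numpy.random.rand(100, 2).astype(numpy.float32)
y = (X[:, 0] > X[:, 1]).astype(int)
booster = lightgbm.train({'objective': 'binary'}, lightgbm.Dataset(X, label=y))

# initial_types is mandatory; the 'input' name and [1, 2] shape are assumptions.
onnx_model = convert_lightgbm(booster, 'lightgbm-booster',
                              initial_types=[('input', FloatTensorType([1, 2]))])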

6 files changed: +124 −27 lines

onnxmltools/convert/lightgbm/_parse.py

Lines changed: 38 additions & 5 deletions
@@ -3,9 +3,11 @@
 # Licensed under the MIT License. See License.txt in the project root for
 # license information.
 # --------------------------------------------------------------------------
+import numpy

 from ..common._container import LightGbmModelContainer
 from ..common._topology import *
+from ..common.data_types import FloatTensorType

 from lightgbm import LGBMClassifier, LGBMRegressor

@@ -16,14 +18,44 @@
 lightgbm_operator_name_map = {LGBMClassifier: 'LgbmClassifier',
                               LGBMRegressor: 'LgbmRegressor'}

-
-def _get_lightgbm_operator_name(model_type):
+class WrappedBooster:
+
+    def __init__(self, booster):
+        self.booster_ = booster
+        _model_dict = self.booster_.dump_model()
+        self.classes_ = self._generate_classes(_model_dict)
+        self.n_features_ = len(_model_dict['feature_names'])
+        if _model_dict['objective'].startswith('binary'):
+            self.operator_name = 'LgbmClassifier'
+        elif _model_dict['objective'].startswith('regression'):
+            self.operator_name = 'LgbmRegressor'
+        else:
+            # Multiclass classifiers are not supported at the moment. The exported ONNX model
+            # has an issue in its ZipMap node.
+            raise ValueError('unsupported LightGbm objective: {}'.format(_model_dict['objective']))
+        if _model_dict.get('average_output', False):
+            self.boosting_type = 'rf'
+        else:
+            # Other than random forest, the boosting type does not affect later conversion,
+            # so `gbdt` is used as an arbitrary default.
+            self.boosting_type = 'gbdt'
+
+    def _generate_classes(self, model_dict):
+        if model_dict['num_class'] == 1:
+            return numpy.asarray([0, 1])
+        return numpy.arange(model_dict['num_class'])
+
+
+def _get_lightgbm_operator_name(model):
     '''
     Get operator name of the input argument
 
-    :param model_type: A lightgbm object.
+    :param model: A lightgbm object.
     :return: A string which stands for the type of the input model in our conversion framework
     '''
+    if isinstance(model, WrappedBooster):
+        return model.operator_name
+    model_type = type(model)
     if model_type not in lightgbm_operator_name_map:
         raise ValueError("No proper operator name found for '%s'" % model_type)
     return lightgbm_operator_name_map[model_type]
@@ -38,10 +70,11 @@ def _parse_lightgbm_simple_model(scope, model, inputs):
     :param inputs: A list of variables
     :return: A list of output variables which will be passed to next stage
     '''
-    this_operator = scope.declare_local_operator(_get_lightgbm_operator_name(type(model)), model)
+    operator_name = _get_lightgbm_operator_name(model)
+    this_operator = scope.declare_local_operator(operator_name, model)
    this_operator.inputs = inputs

-    if type(model) in lightgbm_classifier_list:
+    if operator_name == 'LgbmClassifier':
         # For classifiers, we may have two outputs, one for label and the other one for probabilities of all classes.
         # Notice that their types here are not necessarily correct and they will be fixed in shape inference phase
         label_variable = scope.declare_local_variable('label', FloatTensorType())
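
For orientation, the sketch below (not from the commit) shows the kind of dump_model() fields WrappedBooster relies on; real dumps contain many more keys, and the exact objective string ('binary sigmoid:1' here) is an assumption that depends on the training parameters:

# Assumed shape of booster.dump_model() output, trimmed to the keys used by WrappedBooster.
example_dump = {
    'objective': 'binary sigmoid:1',            # matched with startswith('binary')
    'feature_names': ['Column_0', 'Column_1'],  # len() gives n_features_
    'num_class': 1,                             # 1 means classes_ = numpy.asarray([0, 1])
    'average_output': False,                    # True would select boosting_type 'rf'
    'tree_info': [],                            # per-tree structures, consumed by the operator converter
}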

onnxmltools/convert/lightgbm/convert.py

Lines changed: 12 additions & 5 deletions
@@ -5,6 +5,10 @@
 # --------------------------------------------------------------------------

 from uuid import uuid4
+
+import lightgbm
+
+from onnxmltools.convert.lightgbm._parse import WrappedBooster
 from ...proto import onnx, get_opset_number_from_onnx
 from ..common._topology import convert_topology
 from ._parse import parse_lightgbm
@@ -21,10 +25,11 @@ def convert(model, name=None, initial_types=None, doc_string='', target_opset=No
     This function produces an equivalent ONNX model of the given lightgbm model.
     The supported lightgbm modules are listed below.

-    * `LGBMClassifiers <http://lightgbm.readthedocs.io/en/latest/Python-API.html#lightgbm.LGBMClassifier>`_
-    * `LGBMRegressor <http://lightgbm.readthedocs.io/en/latest/Python-API.html#lightgbm.LGBMRegressor>`_
+    * `LGBMClassifiers <https://lightgbm.readthedocs.io/en/latest/pythonapi/lightgbm.LGBMClassifier.html>`_
+    * `LGBMRegressor <https://lightgbm.readthedocs.io/en/latest/pythonapi/lightgbm.LGBMRegressor.html>`_
+    * `Booster <https://lightgbm.readthedocs.io/en/latest/pythonapi/lightgbm.Booster.html>`_

-    :param model: A lightgbm model
+    :param model: A LightGBM model
     :param initial_types: a python list. Each element is a tuple of a variable name and a type defined in data_types.py
     :param name: The name of the graph (type: GraphProto) in the produced ONNX model (type: ModelProto)
     :param doc_string: A string attached onto the produced ONNX model
@@ -36,8 +41,10 @@ def convert(model, name=None, initial_types=None, doc_string='', target_opset=No
     :return: An ONNX model (type: ModelProto) which is equivalent to the input lightgbm model
     '''
     if initial_types is None:
-        raise ValueError('Initial types are required. See usage of convert(...) in \
-                         onnxmltools.convert.lightgbm.convert for details')
+        raise ValueError('Initial types are required. See usage of convert(...) in '
+                         'onnxmltools.convert.lightgbm.convert for details')
+    if isinstance(model, lightgbm.Booster):
+        model = WrappedBooster(model)
     if name is None:
         name = str(uuid4().hex)
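
The same entry point covers the command-line workflow mentioned in the commit message: a model.txt produced by the LightGBM CLI can only be loaded back as a Booster, which convert() now wraps automatically. A minimal sketch, assuming a binary model saved as model.txt with two features (both assumptions):

import lightgbm
from onnxmltools.convert import convert_lightgbm
from onnxmltools.convert.common.data_types import FloatTensorType

# Reload a CLI-trained model and convert it; the file name and input shape are illustrative.
booster = lightgbm.Booster(model_file='model.txt')
onnx_model = convert_lightgbm(booster, 'cli-trained-model',
                              initial_types=[('input', FloatTensorType([1, 2]))])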

onnxmltools/convert/lightgbm/operator_converters/LightGbm.py

Lines changed: 21 additions & 14 deletions
@@ -70,8 +70,10 @@ def _parse_tree_structure(tree_id, class_id, learning_rate, tree_structure, attr
     else:
         attrs['nodes_missing_value_tracks_true'].append(0)
     attrs['nodes_hitrates'].append(1.)
-    _parse_node(tree_id, class_id, left_id, node_id_pool, learning_rate, tree_structure['left_child'], attrs)
-    _parse_node(tree_id, class_id, right_id, node_id_pool, learning_rate, tree_structure['right_child'], attrs)
+    _parse_node(tree_id, class_id, left_id, node_id_pool, learning_rate,
+                tree_structure['left_child'], attrs)
+    _parse_node(tree_id, class_id, right_id, node_id_pool, learning_rate,
+                tree_structure['right_child'], attrs)


 def _parse_node(tree_id, class_id, node_id, node_id_pool, learning_rate, node, attrs):
@@ -97,8 +99,10 @@ def _parse_node(tree_id, class_id, node_id, node_id_pool, learning_rate, node, a
         attrs['nodes_hitrates'].append(1.)

         # Recursively dive into the child nodes
-        _parse_node(tree_id, class_id, left_id, node_id_pool, learning_rate, node['left_child'], attrs)
-        _parse_node(tree_id, class_id, right_id, node_id_pool, learning_rate, node['right_child'], attrs)
+        _parse_node(tree_id, class_id, left_id, node_id_pool, learning_rate, node['left_child'],
+                    attrs)
+        _parse_node(tree_id, class_id, right_id, node_id_pool, learning_rate, node['right_child'],
+                    attrs)
     elif hasattr(node, 'left_child') or hasattr(node, 'right_child'):
         raise ValueError('Need two branches')
     else:
@@ -130,19 +134,20 @@ def convert_lightgbm(scope, operator, container):

     attrs = get_default_tree_classifier_attribute_pairs()
     attrs['name'] = operator.full_name
-
+
     # Create different attributes for classifier and regressor, respectively
-    if isinstance(gbm_model, LGBMClassifier):
+    if gbm_text['objective'].startswith('binary'):
+        n_classes = 1
+        attrs['post_transform'] = 'LOGISTIC'
+    elif gbm_text['objective'].startswith('multiclass'):
         n_classes = gbm_text['num_class']
-        if gbm_model.objective_ == 'multiclass':
-            attrs['post_transform'] = 'SOFTMAX'
-        else:
-            attrs['post_transform'] = 'LOGISTIC'
-    else:
+        attrs['post_transform'] = 'SOFTMAX'
+    elif gbm_text['objective'].startswith('regression'):
         n_classes = 1 # Regressor has only one output variable
         attrs['post_transform'] = 'NONE'
         attrs['n_targets'] = n_classes
-
+    else:
+        assert False, 'LightGBM objective should be cleaned already'
     # Use the same algorithm to parse the tree
     for i, tree in enumerate(gbm_text['tree_info']):
         tree_id = i
@@ -156,7 +161,8 @@ def convert_lightgbm(scope, operator, container):
     tree_number = len(node_numbers_per_tree.keys())
     accumulated_node_numbers = [0] * tree_number
     for i in range(1, tree_number):
-        accumulated_node_numbers[i] = accumulated_node_numbers[i - 1] + node_numbers_per_tree[i - 1]
+        accumulated_node_numbers[i] = (accumulated_node_numbers[i - 1]
+                                       + node_numbers_per_tree[i - 1])
     global_node_indexes = []
     for i in range(len(attrs['nodes_nodeids'])):
         tree_id = attrs['nodes_treeids'][i]
@@ -169,7 +175,8 @@ def convert_lightgbm(scope, operator, container):
         attrs[k] = sorted_list

     # Create ONNX object
-    if isinstance(gbm_model, LGBMClassifier):
+    if (gbm_text['objective'].startswith('binary')
+            or gbm_text['objective'].startswith('multiclass')):
         # Prepare label information for both of TreeEnsembleClassifier and ZipMap
         class_type = onnx_proto.TensorProto.STRING
         zipmap_attrs = {'name': scope.get_unique_variable_name('ZipMap')}
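
The converter now branches on the objective string from dump_model() instead of the scikit-learn class, because a raw Booster is neither an LGBMClassifier nor an LGBMRegressor. The post_transform values it picks correspond to the usual link functions applied to the summed tree outputs; the numbers below are a rough sketch of that behaviour, not onnxruntime's actual implementation:

import numpy

# 'LOGISTIC' (binary objective): sigmoid of the summed leaf values for one sample.
raw_score = numpy.array([0.7])
logistic = 1.0 / (1.0 + numpy.exp(-raw_score))
# 'SOFTMAX' (multiclass objective): normalised exponentials over per-class scores.
raw_multi = numpy.array([0.2, 1.1, -0.3])
softmax = numpy.exp(raw_multi) / numpy.exp(raw_multi).sum()
# 'NONE' (regression objective): the summed value is returned unchanged.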

onnxmltools/utils/tests_helper.py

Lines changed: 20 additions & 1 deletion
@@ -75,7 +75,19 @@ def dump_data_and_model(data, model, onnx=None, basename="model", folder=None,
         os.makedirs(folder)

     if hasattr(model, "predict"):
-        if hasattr(model, "predict_proba"):
+        import lightgbm
+        if isinstance(model, lightgbm.Booster):
+            # LightGBM Booster
+            model_dict = model.dump_model()
+            if model_dict['objective'].startswith('binary'):
+                score = model.predict(data)
+                prediction = [score > 0.5, numpy.vstack([1-score, score]).T]
+            elif model_dict['objective'].startswith('multiclass'):
+                score = model.predict(data)
+                prediction = [score.argmax(axis=1), score]
+            else:
+                prediction = [model.predict(data)]
+        elif hasattr(model, "predict_proba"):
             # Classifier
             prediction = [model.predict(data), model.predict_proba(data)]
         elif hasattr(model, "decision_function"):
@@ -172,6 +184,13 @@ def convert_model(model, name, input_types):
     elif model.__class__.__name__.startswith("XGB"):
         from onnxmltools.convert import convert_xgboost
         model, prefix = convert_xgboost(model, name, input_types), "XGB"
+    elif model.__class__.__name__ == 'Booster':
+        import lightgbm
+        if isinstance(model, lightgbm.Booster):
+            from onnxmltools.convert import convert_lightgbm
+            model, prefix = convert_lightgbm(model, name, input_types), "LightGbm"
+        else:
+            raise RuntimeError("Unable to convert model of type '{0}'.".format(type(model)))
     elif isinstance(model, BaseEstimator):
         from onnxmltools.convert import convert_sklearn
         model, prefix = convert_sklearn(model, name, input_types), "Sklearn"
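
The Booster branch above exists because lightgbm.Booster.predict returns probabilities rather than labels: a 1-D array of positive-class probabilities for binary objectives and an (n_samples, num_class) array for multiclass. A small self-contained sketch of the binary-case reshaping, using made-up scores:

import numpy

# Pretend Booster.predict already returned positive-class probabilities for 4 samples.
score = numpy.array([0.1, 0.8, 0.3, 0.9])
labels = score > 0.5                                # thresholded labels
probabilities = numpy.vstack([1 - score, score]).T  # shape (n_samples, 2)
prediction = [labels, probabilities]                # same layout as for sklearn classifiers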

onnxmltools/utils/utils_backend_onnxruntime.py

Lines changed: 2 additions & 0 deletions
@@ -4,6 +4,8 @@
 import os
 import glob
 import pickle
+import warnings
+
 import numpy
 from numpy.testing import assert_array_almost_equal, assert_array_equal
 from .utils_backend import load_data_and_model, extract_options, ExpectedAssertionError, OnnxRuntimeAssertionError, compare_outputs

tests/lightgbm/test_LightGbmTreeEnsembleConverters.py

Lines changed: 31 additions & 2 deletions
@@ -5,13 +5,15 @@
 # --------------------------------------------------------------------------

 import unittest
+
+import lightgbm
 import numpy
 from lightgbm import LGBMClassifier, LGBMRegressor
-from onnxmltools import convert_lightgbm
 from onnxmltools.convert.common.data_types import FloatTensorType
 from onnxmltools.utils import dump_data_and_model
 from onnxmltools.utils import dump_binary_classification, dump_multiple_classification
-from onnxmltools.utils import dump_multiple_regression, dump_single_regression
+from onnxmltools.utils import dump_single_regression
+from onnxmltools.utils.tests_helper import convert_model


 class TestLightGbmTreeEnsembleModels(unittest.TestCase):
@@ -33,6 +35,33 @@ def test_lightgbm_regressor2(self):
         model = LGBMRegressor(n_estimators=2, max_depth=1, min_child_samples=1)
         dump_single_regression(model, suffix="2")

+    def test_lightgbm_booster_classifier(self):
+        X = [[0, 1], [1, 1], [2, 0], [1, 2]]
+        X = numpy.array(X, dtype=numpy.float32)
+        y = [0, 1, 0, 1]
+        data = lightgbm.Dataset(X, label=y)
+        model = lightgbm.train({'boosting_type': 'gbdt', 'objective': 'binary',
+                                'n_estimators': 3, 'min_child_samples': 1},
+                               data)
+        model_onnx, prefix = convert_model(model, 'tree-based multi-output classifier',
+                                           [('input', FloatTensorType([1, 2]))])
+        dump_data_and_model(X, model, model_onnx,
+                            allow_failure="StrictVersion(onnx.__version__) < StrictVersion('1.3.0')",
+                            basename=prefix + "BoosterBin" + model.__class__.__name__)
+
+    def test_lightgbm_booster_regressor(self):
+        X = [[0, 1], [1, 1], [2, 0]]
+        X = numpy.array(X, dtype=numpy.float32)
+        y = [0, 1, 1.1]
+        data = lightgbm.Dataset(X, label=y)
+        model = lightgbm.train({'boosting_type': 'gbdt', 'objective': 'regression',
+                                'n_estimators': 3, 'min_child_samples': 1, 'max_depth': 1},
+                               data)
+        model_onnx, prefix = convert_model(model, 'tree-based binary classifier',
+                                           [('input', FloatTensorType([1, 2]))])
+        dump_data_and_model(X, model, model_onnx,
+                            basename=prefix + "BoosterBin" + model.__class__.__name__)
+

 if __name__ == "__main__":
     unittest.main()
