Skip to content

Commit cb2782b

Browse files
authored
Improves lightgbm conversion speed (#491)
* improves lightgbm conversion speed
1 parent 3d81a0a commit cb2782b

File tree

5 files changed

+310
-57
lines changed

5 files changed

+310
-57
lines changed

.azure-pipelines/linux-conda-CI.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ jobs:
9797
displayName: 'Install dependencies'
9898
9999
- script: |
100+
pip install flake8
100101
python -m flake8 ./onnxmltools
101102
displayName: 'run flake8 check'
102103

onnxmltools/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
This framework converts any machine learned model into onnx format
66
which is a common language to describe any machine learned model.
77
"""
8-
__version__ = "1.8.0"
8+
__version__ = "1.9.0"
99
__author__ = "Microsoft"
1010
__producer__ = "OnnxMLTools"
1111
__producer_version__ = __version__

onnxmltools/convert/lightgbm/_parse.py

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,28 +21,49 @@ class WrappedBooster:
2121

2222
def __init__(self, booster):
2323
self.booster_ = booster
24-
_model_dict = self.booster_.dump_model()
25-
self.classes_ = self._generate_classes(_model_dict)
26-
self.n_features_ = len(_model_dict['feature_names'])
27-
if (_model_dict['objective'].startswith('binary') or
28-
_model_dict['objective'].startswith('multiclass')):
24+
self.n_features_ = self.booster_.feature_name()
25+
self.objective_ = self.get_objective()
26+
if self.objective_.startswith('binary'):
2927
self.operator_name = 'LgbmClassifier'
30-
elif _model_dict['objective'].startswith(('regression', 'poisson', 'gamma')):
28+
self.classes_ = self._generate_classes(booster)
29+
elif self.objective_.startswith('multiclass'):
30+
self.operator_name = 'LgbmClassifier'
31+
self.classes_ = self._generate_classes(booster)
32+
elif self.objective_.startswith('regression'):
3133
self.operator_name = 'LgbmRegressor'
3234
else:
33-
# Other objectives are not supported.
34-
raise ValueError("Unsupported LightGbm objective: '{}'.".format(_model_dict['objective']))
35-
if _model_dict.get('average_output', False):
35+
raise NotImplementedError(
36+
'Unsupported LightGbm objective: %r.' % self.objective_)
37+
average_output = self.booster_.attr('average_output')
38+
if average_output:
3639
self.boosting_type = 'rf'
3740
else:
3841
# Apart from random forest, the boosting type does not affect later conversion.
4042
# `gbdt` is therefore used as an arbitrary default.
4043
self.boosting_type = 'gbdt'
4144

42-
def _generate_classes(self, model_dict):
43-
if model_dict['num_class'] == 1:
45+
@staticmethod
46+
def _generate_classes(booster):
47+
if isinstance(booster, dict):
48+
num_class = booster['num_class']
49+
else:
50+
num_class = booster.attr('num_class')
51+
if num_class is None:
52+
dp = booster.dump_model(num_iteration=1)
53+
num_class = dp['num_class']
54+
if num_class == 1:
4455
return numpy.asarray([0, 1])
45-
return numpy.arange(model_dict['num_class'])
56+
return numpy.arange(num_class)
57+
58+
def get_objective(self):
59+
"Returns the objective."
60+
if hasattr(self, 'objective_') and self.objective_ is not None:
61+
return self.objective_
62+
objective = self.booster_.attr('objective')
63+
if objective is not None:
64+
return objective
65+
dp = self.booster_.dump_model(num_iteration=1)
66+
return dp['objective']
4667

4768

4869
def _get_lightgbm_operator_name(model):

0 commit comments

Comments
 (0)