Skip to content

Commit f2c69bb

Browse files
dante.lJ-tt
authored andcommitted
support LGBMRanker conversion
Signed-off-by: Jett Jackson <[email protected]>
1 parent fe55a8a commit f2c69bb

File tree

4 files changed

+17
-2
lines changed

4 files changed

+17
-2
lines changed

onnxmltools/convert/lightgbm/_parse.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,15 @@
77
from ..common.data_types import (FloatTensorType,
88
SequenceType, DictionaryType, StringType, Int64Type)
99

10-
from lightgbm import LGBMClassifier, LGBMRegressor
10+
from lightgbm import LGBMClassifier, LGBMRegressor, LGBMRanker
1111

1212
lightgbm_classifier_list = [LGBMClassifier]
1313

1414
# Associate scikit-learn types with our operator names. If two scikit-learn models share a single name, it means their
1515
# are equivalent in terms of conversion.
1616
lightgbm_operator_name_map = {LGBMClassifier: 'LgbmClassifier',
17-
LGBMRegressor: 'LgbmRegressor'}
17+
LGBMRegressor: 'LgbmRegressor',
18+
LGBMRanker: 'LgbmRanker'}
1819

1920

2021
class WrappedBooster:
@@ -31,6 +32,8 @@ def __init__(self, booster):
3132
self.classes_ = self._generate_classes(booster)
3233
elif self.objective_.startswith('regression'):
3334
self.operator_name = 'LgbmRegressor'
35+
elif self.objective_.startswith('lambdarank'):
36+
self.operator_name = 'LgbmRanker'
3437
else:
3538
raise NotImplementedError(
3639
'Unsupported LightGbm objective: %r.' % self.objective_)

onnxmltools/convert/lightgbm/operator_converters/LightGbm.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -455,6 +455,10 @@ def convert_lightgbm(scope, operator, container):
455455
# so we need to add an 'Exp' post transform node to the model
456456
attrs['post_transform'] = 'NONE'
457457
post_transform = "Exp"
458+
elif gbm_text['objective'].startswith('lambdarank'):
459+
n_classes = 1 # Ranker has only one output variable
460+
attrs['post_transform'] = 'NONE'
461+
attrs['n_targets'] = n_classes
458462
else:
459463
raise RuntimeError(
460464
"LightGBM objective should be cleaned already not '{}'.".format(
@@ -818,3 +822,4 @@ def convert_lgbm_zipmap(scope, operator, container):
818822
register_converter('LgbmClassifier', convert_lightgbm)
819823
register_converter('LgbmRegressor', convert_lightgbm)
820824
register_converter('LgbmZipMap', convert_lgbm_zipmap)
825+
register_converter('LgbmRanker', convert_lightgbm)
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
3+
from ...common._registration import register_shape_calculator
4+
from ...common.shape_calculator import calculate_linear_regressor_output_shapes
5+
6+
register_shape_calculator('LgbmRanker', calculate_linear_regressor_output_shapes)

onnxmltools/convert/lightgbm/shape_calculators/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@
33
# To register shape calculators for lightgbm operators, import associated modules here.
44
from . import Classifier
55
from . import Regressor
6+
from . import Ranker

0 commit comments

Comments
 (0)