diff --git a/onnxmltools/convert/lightgbm/_parse.py b/onnxmltools/convert/lightgbm/_parse.py index b31710b3..8f1b7dc3 100644 --- a/onnxmltools/convert/lightgbm/_parse.py +++ b/onnxmltools/convert/lightgbm/_parse.py @@ -11,7 +11,7 @@ Int64Type, ) -from lightgbm import LGBMClassifier, LGBMRegressor +from lightgbm import LGBMClassifier, LGBMRegressor, LGBMRanker lightgbm_classifier_list = [LGBMClassifier] @@ -21,6 +21,7 @@ lightgbm_operator_name_map = { LGBMClassifier: "LgbmClassifier", LGBMRegressor: "LgbmRegressor", + LGBMRanker: "LgbmRanker", } @@ -35,6 +36,8 @@ def __init__(self, booster): elif self.objective_.startswith("multiclass"): self.operator_name = "LgbmClassifier" self.classes_ = self._generate_classes(booster) + elif self.objective_.startswith("lambdarank"): + self.operator_name = "LgbmRanker" elif self.objective_.startswith( ("regression", "poisson", "gamma", "quantile", "huber", "tweedie") ): diff --git a/onnxmltools/convert/lightgbm/operator_converters/LightGbm.py b/onnxmltools/convert/lightgbm/operator_converters/LightGbm.py index 43a1b5eb..d0c47a1e 100644 --- a/onnxmltools/convert/lightgbm/operator_converters/LightGbm.py +++ b/onnxmltools/convert/lightgbm/operator_converters/LightGbm.py @@ -566,6 +566,10 @@ def convert_lightgbm(scope, operator, container): # so we need to add an 'Exp' post transform node to the model attrs["post_transform"] = "NONE" post_transform = "Exp" + elif gbm_text["objective"].startswith("lambdarank"): + n_classes = 1 # Ranker has only one output variable + attrs["post_transform"] = "NONE" + attrs["n_targets"] = n_classes else: raise RuntimeError( "LightGBM objective should be cleaned already not '{}'.".format( @@ -1026,3 +1030,4 @@ def convert_lgbm_zipmap(scope, operator, container): register_converter("LgbmClassifier", convert_lightgbm) register_converter("LgbmRegressor", convert_lightgbm) register_converter("LgbmZipMap", convert_lgbm_zipmap) +register_converter("LgbmRanker", convert_lightgbm) diff --git a/onnxmltools/convert/lightgbm/shape_calculators/Ranker.py b/onnxmltools/convert/lightgbm/shape_calculators/Ranker.py new file mode 100644 index 00000000..67eaf146 --- /dev/null +++ b/onnxmltools/convert/lightgbm/shape_calculators/Ranker.py @@ -0,0 +1,6 @@ +# SPDX-License-Identifier: Apache-2.0 + +from ...common._registration import register_shape_calculator +from ...common.shape_calculator import calculate_linear_regressor_output_shapes + +register_shape_calculator("LgbmRanker", calculate_linear_regressor_output_shapes) diff --git a/onnxmltools/convert/lightgbm/shape_calculators/__init__.py b/onnxmltools/convert/lightgbm/shape_calculators/__init__.py index e7a2c3d9..18880b10 100644 --- a/onnxmltools/convert/lightgbm/shape_calculators/__init__.py +++ b/onnxmltools/convert/lightgbm/shape_calculators/__init__.py @@ -3,3 +3,4 @@ # To register shape calculators for lightgbm operators, import associated modules here. from . import Classifier from . import Regressor +from . import Ranker