Skip to content

Commit 5935b56

Browse files
ju4nv1e1r4xadupre
andauthored
Add docstring for 2 classes and 1 function (#715)
Co-authored-by: Xavier Dupré <[email protected]>
1 parent 40fd362 commit 5935b56

File tree

1 file changed

+23
-0
lines changed
  • onnxmltools/convert/xgboost/operator_converters

1 file changed

+23
-0
lines changed

onnxmltools/convert/xgboost/operator_converters/XGBoost.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,11 @@
1414

1515

1616
class XGBConverter:
17+
"""
18+
Base class for converting XGBoost models to ONNX format.
19+
This class provides methods to validate the model, retrieve parameters,
20+
and fill in the attributes for the ONNX TreeEnsemble node.
21+
"""
1722
@staticmethod
1823
def get_xgb_params(xgb_node):
1924
"""
@@ -222,6 +227,13 @@ def fill_tree_attributes(js_xgb_node, attr_pairs, tree_weights, is_classifier):
222227

223228

224229
class XGBRegressorConverter(XGBConverter):
230+
"""
231+
Converter for XGBoost Regressor models to ONNX format.
232+
This class inherits from XGBConverter and implements the conversion
233+
logic specific to regression tasks.
234+
It handles the conversion of model parameters, tree structure,
235+
and the creation of the ONNX TreeEnsembleRegressor node.
236+
"""
225237
@staticmethod
226238
def validate(xgb_node):
227239
return XGBConverter.validate(xgb_node)
@@ -423,6 +435,17 @@ def convert(scope, operator, container):
423435

424436

425437
def convert_xgboost(scope, operator, container):
438+
"""
439+
Converts an XGBoost model (XGBClassifier or XGBRegressor) into an ONNX TreeEnsemble node.
440+
441+
Parameters:
442+
scope: Object for managing variable names in the ONNX graph.
443+
operator: Wrapper for the XGBoost model and its input/output variables.
444+
container: Object to which the ONNX nodes will be added.
445+
446+
This function dispatches the conversion to the appropriate internal converter
447+
based on whether the model is a classifier or regressor.
448+
"""
426449
xgb_node = operator.raw_operator
427450
if isinstance(xgb_node, (XGBClassifier, XGBRFClassifier)) or getattr(
428451
xgb_node, "operator_name", None

0 commit comments

Comments
 (0)