|
14 | 14 |
|
15 | 15 |
|
16 | 16 | class XGBConverter: |
| 17 | + """ |
| 18 | + Base class for converting XGBoost models to ONNX format. |
| 19 | + This class provides methods to validate the model, retrieve parameters, |
| 20 | + and fill in the attributes for the ONNX TreeEnsemble node. |
| 21 | + """ |
17 | 22 | @staticmethod |
18 | 23 | def get_xgb_params(xgb_node): |
19 | 24 | """ |
@@ -222,6 +227,13 @@ def fill_tree_attributes(js_xgb_node, attr_pairs, tree_weights, is_classifier): |
222 | 227 |
|
223 | 228 |
|
224 | 229 | class XGBRegressorConverter(XGBConverter): |
| 230 | + """ |
| 231 | + Converter for XGBoost Regressor models to ONNX format. |
| 232 | + This class inherits from XGBConverter and implements the conversion |
| 233 | + logic specific to regression tasks. |
| 234 | + It handles the conversion of model parameters, tree structure, |
| 235 | + and the creation of the ONNX TreeEnsembleRegressor node. |
| 236 | + """ |
225 | 237 | @staticmethod |
226 | 238 | def validate(xgb_node): |
227 | 239 | return XGBConverter.validate(xgb_node) |
@@ -423,6 +435,17 @@ def convert(scope, operator, container): |
423 | 435 |
|
424 | 436 |
|
425 | 437 | def convert_xgboost(scope, operator, container): |
| 438 | + """ |
| 439 | + Converts an XGBoost model (XGBClassifier or XGBRegressor) into an ONNX TreeEnsemble node. |
| 440 | +
|
| 441 | + Parameters: |
| 442 | + scope: Object for managing variable names in the ONNX graph. |
| 443 | + operator: Wrapper for the XGBoost model and its input/output variables. |
| 444 | + container: Object to which the ONNX nodes will be added. |
| 445 | +
|
| 446 | + This function dispatches the conversion to the appropriate internal converter |
| 447 | + based on whether the model is a classifier or regressor. |
| 448 | + """ |
426 | 449 | xgb_node = operator.raw_operator |
427 | 450 | if isinstance(xgb_node, (XGBClassifier, XGBRFClassifier)) or getattr( |
428 | 451 | xgb_node, "operator_name", None |
|
0 commit comments