Skip to content

Commit c1f9966

Browse files
committed
WL#17088 MySQL Connector Python HeatWave ML SDK
Issue: MySQL AI services are difficult to integrate with Python data-processing pipelines. Solution: Provide an easy, intuitive API for performing ML queries inside MySQL databases through MySQL AI. The ML packages are designed for ease of use with pandas data types. Change-Id: I41f2c297815f8199f90bc69368ad07fff66a3f43
1 parent 16ba058 commit c1f9966

File tree

14 files changed

+3113
-0
lines changed

14 files changed

+3113
-0
lines changed
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Copyright (c) 2025 Oracle and/or its affiliates.
2+
#
3+
# This program is free software; you can redistribute it and/or modify
4+
# it under the terms of the GNU General Public License, version 2.0, as
5+
# published by the Free Software Foundation.
6+
#
7+
# This program is designed to work with certain software (including
8+
# but not limited to OpenSSL) that is licensed under separate terms,
9+
# as designated in a particular file or component or in included license
10+
# documentation. The authors of MySQL hereby grant you an
11+
# additional permission to link the program and your derivative works
12+
# with the separately licensed software that they have either included with
13+
# the program or referenced in the documentation.
14+
#
15+
# Without limiting anything contained in the foregoing, this file,
16+
# which is part of MySQL Connector/Python, is also subject to the
17+
# Universal FOSS Exception, version 1.0, a copy of which can be found at
18+
# http://oss.oracle.com/licenses/universal-foss-exception.
19+
#
20+
# This program is distributed in the hope that it will be useful, but
21+
# WITHOUT ANY WARRANTY; without even the implied warranty of
22+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
23+
# See the GNU General Public License, version 2.0, for more details.
24+
#
25+
# You should have received a copy of the GNU General Public License
26+
# along with this program; if not, write to the Free Software Foundation, Inc.,
27+
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
28+
29+
from mysql.ai.utils import check_dependencies as _check_dependencies
30+
31+
# Runs at package import time, before the submodule imports below.
# Presumably verifies that the optional "ML" dependency group is installed —
# confirm against mysql.ai.utils.check_dependencies.
# The helper alias is deleted afterwards so it does not remain in this
# package's namespace.
_check_dependencies(["ML"])
32+
del _check_dependencies
33+
34+
# Sklearn models
35+
from .classifier import MyClassifier
36+
37+
# Minimal interface
38+
from .model import ML_TASK, MyModel
39+
from .outlier import MyAnomalyDetector
40+
from .regressor import MyRegressor
41+
from .transformer import MyGenericTransformer
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
# Copyright (c) 2025 Oracle and/or its affiliates.
2+
#
3+
# This program is free software; you can redistribute it and/or modify
4+
# it under the terms of the GNU General Public License, version 2.0, as
5+
# published by the Free Software Foundation.
6+
#
7+
# This program is designed to work with certain software (including
8+
# but not limited to OpenSSL) that is licensed under separate terms,
9+
# as designated in a particular file or component or in included license
10+
# documentation. The authors of MySQL hereby grant you an
11+
# additional permission to link the program and your derivative works
12+
# with the separately licensed software that they have either included with
13+
# the program or referenced in the documentation.
14+
#
15+
# Without limiting anything contained in the foregoing, this file,
16+
# which is part of MySQL Connector/Python, is also subject to the
17+
# Universal FOSS Exception, version 1.0, a copy of which can be found at
18+
# http://oss.oracle.com/licenses/universal-foss-exception.
19+
#
20+
# This program is distributed in the hope that it will be useful, but
21+
# WITHOUT ANY WARRANTY; without even the implied warranty of
22+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
23+
# See the GNU General Public License, version 2.0, for more details.
24+
#
25+
# You should have received a copy of the GNU General Public License
26+
# along with this program; if not, write to the Free Software Foundation, Inc.,
27+
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
28+
29+
from typing import Optional, Union
30+
31+
import pandas as pd
32+
33+
from mysql.ai.ml.model import ML_TASK, MyModel
34+
from mysql.ai.utils import copy_dict
35+
from sklearn.base import BaseEstimator
36+
37+
from mysql.connector.abstracts import MySQLConnectionAbstract
38+
39+
40+
class MyBaseMLModel(BaseEstimator):
    """
    Base class for MySQL HeatWave machine learning estimators.

    Wraps a MyModel database helper behind a scikit-learn style estimator
    and provides the operations shared by all estimators: fitting, deleting
    the model from the model catalog, and reading its catalog entry.
    For use as a base class by classifiers, regressors, transformers, and
    outlier models.

    Args:
        db_connection (MySQLConnectionAbstract): An active MySQL connector
            database connection.
        task (str or ML_TASK): ML task type, e.g. "classification" or
            "regression".
        model_name (str, optional): Custom name for the deployed model.
        fit_extra_options (dict, optional): Extra options for fitting.

    Attributes:
        _model: Underlying MyModel helper that every operation is delegated to.
        fit_extra_options: The user-provided fit options, as returned by
            copy_dict.
    """

    def __init__(
        self,
        db_connection: MySQLConnectionAbstract,
        task: Union[str, ML_TASK],
        model_name: Optional[str] = None,
        fit_extra_options: Optional[dict] = None,
    ):
        """
        Initialize a MyBaseMLModel with connection, task, and option parameters.

        Args:
            db_connection: Active MySQL connector database connection.
            task: ML task, either a string label (e.g. "classification")
                or an ML_TASK member.
            model_name: Optional custom model name.
            fit_extra_options: Optional extra fit options.

        Raises:
            DatabaseError:
                If a database connection issue occurs.
                If an operational error occurs during execution.
        """
        # NOTE(review): scikit-learn's BaseEstimator.get_params() reads every
        # __init__ parameter back with getattr(self, name). db_connection,
        # task and model_name are not stored under those names here, so
        # get_params()/clone()/repr() may fail — confirm whether this is
        # intended or handled elsewhere.
        self._model = MyModel(db_connection, task=task, model_name=model_name)
        self.fit_extra_options = copy_dict(fit_extra_options)

    def fit(
        self,
        X: pd.DataFrame,
        y: Optional[Union[pd.DataFrame, pd.Series]] = None,
    ) -> "MyBaseMLModel":
        """
        Fit the underlying ML model using pandas data.

        Delegates to the fit method of the underlying MyModel helper,
        forwarding fit_extra_options as the fit options.

        Args:
            X: Features DataFrame.
            y: (Optional) Target labels DataFrame or Series.

        Returns:
            self

        Raises:
            DatabaseError:
                If provided options are invalid or unsupported.
                If a database connection issue occurs.
                If an operational error occurs during execution.

        Notes:
            Additional temp SQL resources may be created and cleaned up during the operation.
        """
        self._model.fit(X, y, self.fit_extra_options)
        # Returning self allows scikit-learn style call chaining.
        return self

    def _delete_model(self) -> bool:
        """
        Delete the model from the model catalog if present.

        Returns:
            Whether the model was deleted.

        Raises:
            DatabaseError:
                If a database connection issue occurs.
                If an operational error occurs during execution.
        """
        return self._model._delete_model()

    def get_model_info(self) -> Optional[dict]:
        """
        Retrieve this model's entry from the model catalog.

        Model info will only be present in the catalog if the model has
        previously been fitted.

        Returns:
            The model info reported by the underlying MyModel helper.
            Presumably None when the model is not part of the model catalog
            (per the Optional return type) — confirm against
            MyModel.get_model_info.

        Raises:
            DatabaseError:
                If a database connection issue occurs.
                If an operational error occurs during execution.
        """
        return self._model.get_model_info()
Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
# Copyright (c) 2025 Oracle and/or its affiliates.
2+
#
3+
# This program is free software; you can redistribute it and/or modify
4+
# it under the terms of the GNU General Public License, version 2.0, as
5+
# published by the Free Software Foundation.
6+
#
7+
# This program is designed to work with certain software (including
8+
# but not limited to OpenSSL) that is licensed under separate terms,
9+
# as designated in a particular file or component or in included license
10+
# documentation. The authors of MySQL hereby grant you an
11+
# additional permission to link the program and your derivative works
12+
# with the separately licensed software that they have either included with
13+
# the program or referenced in the documentation.
14+
#
15+
# Without limiting anything contained in the foregoing, this file,
16+
# which is part of MySQL Connector/Python, is also subject to the
17+
# Universal FOSS Exception, version 1.0, a copy of which can be found at
18+
# http://oss.oracle.com/licenses/universal-foss-exception.
19+
#
20+
# This program is distributed in the hope that it will be useful, but
21+
# WITHOUT ANY WARRANTY; without even the implied warranty of
22+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
23+
# See the GNU General Public License, version 2.0, for more details.
24+
#
25+
# You should have received a copy of the GNU General Public License
26+
# along with this program; if not, write to the Free Software Foundation, Inc.,
27+
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
28+
29+
from typing import Optional, Union
30+
31+
import numpy as np
32+
import pandas as pd
33+
34+
from mysql.ai.ml.base import MyBaseMLModel
35+
from mysql.ai.ml.model import ML_TASK
36+
from mysql.ai.utils import copy_dict
37+
from sklearn.base import ClassifierMixin
38+
39+
from mysql.connector.abstracts import MySQLConnectionAbstract
40+
41+
42+
class MyClassifier(MyBaseMLModel, ClassifierMixin):
    """
    MySQL HeatWave scikit-learn compatible classifier estimator.

    Provides prediction, probability, and explanation output from a model
    deployed in MySQL, and manages fit, explain, and prediction options as
    per the HeatWave ML interface.

    Attributes:
        predict_extra_options (dict): Dictionary of optional parameters passed through
            to the MySQL backend for prediction and probability inference.
        explain_extra_options (dict): Dictionary of optional parameters passed through
            to the MySQL backend when explaining predictions.
        _model (MyModel): Underlying interface for database model operations.
        fit_extra_options (dict): See MyBaseMLModel.

    Args:
        db_connection (MySQLConnectionAbstract): Active MySQL connector DB connection.
        model_name (str, optional): Custom name for the model.
        fit_extra_options (dict, optional): Extra options for fitting.
        explain_extra_options (dict, optional): Extra options for explanations.
        predict_extra_options (dict, optional): Extra options for predict/predict_proba.

    Methods:
        predict(X): Predict class labels.
        predict_proba(X): Predict class probabilities.
        explain_predictions(X): Explain the predictions made for X.
    """

    def __init__(
        self,
        db_connection: MySQLConnectionAbstract,
        model_name: Optional[str] = None,
        fit_extra_options: Optional[dict] = None,
        explain_extra_options: Optional[dict] = None,
        predict_extra_options: Optional[dict] = None,
    ):
        """
        Initialize a MyClassifier.

        Args:
            db_connection: Active MySQL connector database connection.
            model_name: Optional, custom model name.
            fit_extra_options: Optional fit options.
            explain_extra_options: Optional explain options.
            predict_extra_options: Optional predict/predict_proba options.

        Raises:
            DatabaseError:
                If a database connection issue occurs.
                If an operational error occurs during execution.
        """
        # The task is fixed to classification for this estimator.
        MyBaseMLModel.__init__(
            self,
            db_connection,
            ML_TASK.CLASSIFICATION,
            model_name=model_name,
            fit_extra_options=fit_extra_options,
        )
        self.predict_extra_options = copy_dict(predict_extra_options)
        self.explain_extra_options = copy_dict(explain_extra_options)

    def predict(self, X: Union[pd.DataFrame, np.ndarray]) -> np.ndarray:
        """
        Predict class labels for the input features using the MySQL model.

        References:
            https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-predict-table.html
                A full list of supported options can be found under "ML_PREDICT_TABLE Options"

        Args:
            X: Input samples as a numpy array or pandas DataFrame.

        Returns:
            ndarray: Array of predicted class labels, shape (n_samples,).

        Raises:
            DatabaseError:
                If provided options are invalid or unsupported, or if the model
                is not initialized, i.e., fit or import has not been called
                If a database connection issue occurs.
                If an operational error occurs during execution.
        """
        result = self._model.predict(X, options=self.predict_extra_options)
        # Only the "Prediction" column of the backend result is exposed.
        return result["Prediction"].to_numpy()

    def predict_proba(self, X: Union[pd.DataFrame, np.ndarray]) -> np.ndarray:
        """
        Predict class probabilities for the input features using the MySQL model.

        References:
            https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-predict-table.html
                A full list of supported options can be found under "ML_PREDICT_TABLE Options"

        Args:
            X: Input samples as a numpy array or pandas DataFrame.

        Returns:
            ndarray: Array of shape (n_samples, n_classes) with class
            probabilities. Columns are ordered by the sorted class labels
            found in the first result row.

        Raises:
            DatabaseError:
                If provided options are invalid or unsupported, or if the model
                is not initialized, i.e., fit or import has not been called
                If a database connection issue occurs.
                If an operational error occurs during execution.
        """
        result = self._model.predict(X, options=self.predict_extra_options)
        ml_results = result["ml_results"]

        # Derive a stable column order from the first row's probability keys
        # so every row is laid out identically.
        classes = sorted(ml_results.iloc[0]["probabilities"].keys())

        return np.stack(
            ml_results.map(
                lambda ml_result: [
                    ml_result["probabilities"][class_name] for class_name in classes
                ]
            )
        )

    def explain_predictions(self, X: Union[pd.DataFrame, np.ndarray]) -> pd.DataFrame:
        """
        Explain model predictions using provided data.

        References:
            https://dev.mysql.com/doc/heatwave/en/mys-hwaml-ml-explain-table.html
                A full list of supported options can be found under "ML_EXPLAIN_TABLE Options"

        Args:
            X: DataFrame for which predictions should be explained.

        Returns:
            DataFrame containing explanation details (feature attributions, etc.)

        Raises:
            DatabaseError:
                If provided options are invalid or unsupported, or if the model
                is not initialized, i.e., fit or import has not been called
                If a database connection issue occurs.
                If an operational error occurs during execution.

        Notes:
            Temporary input/output tables are cleaned up after explanation.
        """
        # Propagate the helper's result; without the return, callers would
        # always receive None despite the declared DataFrame return type.
        return self._model.explain_predictions(X, options=self.explain_extra_options)

0 commit comments

Comments
 (0)