Skip to content

Commit 757d629

Browse files
authored
Add/cross decomposition (#150)
* add cross decomposition chain * add cross decomposition models to `pymilo_param.py` * add cross decomposition support to `get_concrete_transporte` function * add test case for `PLSRegression` * add test case for `PLSCanonical` * add test case for `CCA` * add test runner for cross decomposition models * `SUPPORTED_MODELS.md` updated * `CHANGELOG.md` updated * `README.md` updated
1 parent fa61aa7 commit 757d629

File tree

10 files changed

+260
-3
lines changed

10 files changed

+260
-3
lines changed

CHANGELOG.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,19 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
66

77
## [Unreleased]
88
### Added
9+
- `CCA` model
10+
- `PLSCanonical` model
11+
- `PLSRegression` model
12+
- Cross decomposition models test runner
13+
- Cross decomposition chain
914
- PyMilo exception types added in `pymilo/exceptions/__init__.py`
1015
- PyMilo exception types added in `pymilo/__init__.py`
1116
### Changed
17+
- Tests config modified
18+
- Cross decomposition params initialized in `pymilo_param`
19+
- Cross decomposition support added to `pymilo_func.py`
20+
- `SUPPORTED_MODELS.md` updated
21+
- `README.md` updated
1222
- GitHub actions are limited to the `dev` and `main` branches
1323
- `Python 3.13` added to `test.yml`
1424
## [1.0] - 2024-09-16

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,8 @@ Now that you've synced the remote model with your local model, you can run funct
214214
| Ensemble Models ✅ | - |
215215
| Pipeline Model ✅ | - |
216216
| Preprocessing Models ✅ | - |
217+
| Cross Decomposition Models ✅ | - |
218+
217219

218220
Details are available in [Supported Models](https://github.com/openscilab/pymilo/blob/main/SUPPORTED_MODELS.md).
219221

SUPPORTED_MODELS.md

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Supported Models
22

3-
**Last Update: 2024-05-30**
3+
**Last Update: 2024-10-31**
44

55

66
<h2 id="scikit-learn">Scikit-Learn</h2>
@@ -707,3 +707,29 @@
707707
</tr>
708708

709709
</table>
710+
711+
712+
<h3 id="scikit-learn-cross-decomposition">Cross Decomposition Modules</h3>
713+
📚 <a href="https://scikit-learn.org/stable/api/sklearn.cross_decomposition.html" target="_blank"><b>Models Document</b></a>
714+
<table>
715+
<tr align="center">
716+
<th>ID</th>
717+
<th>Model Name</th>
718+
<th>PyMilo Version</th>
719+
</tr>
720+
<tr align="center">
721+
<td>1</td>
722+
<td><b>PLSRegression</b></td>
723+
<td>>=1.1</td>
724+
</tr>
725+
<tr align="center">
726+
<td>2</td>
727+
<td><b>PLSCanonical</b></td>
728+
<td>>=1.1</td>
729+
</tr>
730+
<tr align="center">
731+
<td>3</td>
732+
<td><b>CCA</b></td>
733+
<td>>=1.1</td>
734+
</tr>
735+
</table>
Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
# -*- coding: utf-8 -*-
2+
"""PyMilo chain for cross decomposition models."""
3+
from ..transporters.transporter import Command
4+
5+
from ..transporters.general_data_structure_transporter import GeneralDataStructureTransporter
6+
from ..transporters.preprocessing_transporter import PreprocessingTransporter
7+
8+
from ..pymilo_param import SKLEARN_CROSS_DECOMPOSITION_TABLE
9+
from ..exceptions.serialize_exception import PymiloSerializationException, SerializationErrorTypes
10+
from ..exceptions.deserialize_exception import PymiloDeserializationException, DeserializationErrorTypes
11+
12+
from ..utils.util import get_sklearn_type
13+
14+
from traceback import format_exc
15+
16+
CROSS_DECOMPOSITION_CHAIN = {
17+
"PreprocessingTransporter": PreprocessingTransporter(),
18+
"GeneralDataStructureTransporter": GeneralDataStructureTransporter(),
19+
}
20+
21+
22+
def is_cross_decomposition(model):
23+
"""
24+
Check if the input model is a sklearn's cross decomposition model.
25+
26+
:param model: is a string name of a cross decomposition or a sklearn object of it
27+
:type model: any object
28+
:return: check result as bool
29+
"""
30+
if isinstance(model, str):
31+
return model in SKLEARN_CROSS_DECOMPOSITION_TABLE
32+
else:
33+
return get_sklearn_type(model) in SKLEARN_CROSS_DECOMPOSITION_TABLE.keys()
34+
35+
36+
def transport_cross_decomposition(request, command, is_inner_model=False):
37+
"""
38+
Return the transported (Serialized or Deserialized) model.
39+
40+
:param request: given cross decomposition model to be transported
41+
:type request: any object
42+
:param command: command to specify whether the request should be serialized or deserialized
43+
:type command: transporter.Command
44+
:param is_inner_model: determines whether it is an inner model of a super ml model
45+
:type is_inner_model: boolean
46+
:return: the transported request as a json string or sklearn cross decomposition model
47+
"""
48+
if not is_inner_model:
49+
_validate_input(request, command)
50+
51+
if command == Command.SERIALIZE:
52+
try:
53+
return serialize_cross_decomposition(request)
54+
except Exception as e:
55+
raise PymiloSerializationException(
56+
{
57+
'error_type': SerializationErrorTypes.VALID_MODEL_INVALID_INTERNAL_STRUCTURE,
58+
'error': {
59+
'Exception': repr(e),
60+
'Traceback': format_exc(),
61+
},
62+
'object': request,
63+
})
64+
65+
elif command == Command.DESERIALIZE:
66+
try:
67+
return deserialize_cross_decomposition(request, is_inner_model)
68+
except Exception as e:
69+
raise PymiloDeserializationException(
70+
{
71+
'error_type': SerializationErrorTypes.VALID_MODEL_INVALID_INTERNAL_STRUCTURE,
72+
'error': {
73+
'Exception': repr(e),
74+
'Traceback': format_exc()},
75+
'object': request})
76+
77+
78+
def serialize_cross_decomposition(cross_decomposition_object):
79+
"""
80+
Return the serialized json string of the given cross decomposition model.
81+
82+
:param cross_decomposition_object: given model to be get serialized
83+
:type cross_decomposition_object: any sklearn cross decomposition model
84+
:return: the serialized json string of the given cross decomposition model
85+
"""
86+
for transporter in CROSS_DECOMPOSITION_CHAIN:
87+
CROSS_DECOMPOSITION_CHAIN[transporter].transport(
88+
cross_decomposition_object, Command.SERIALIZE)
89+
return cross_decomposition_object.__dict__
90+
91+
92+
def deserialize_cross_decomposition(cross_decomposition, is_inner_model=False):
93+
"""
94+
Return the associated sklearn cross decomposition model.
95+
96+
:param cross_decomposition: given json string of a cross decomposition model to get deserialized to associated sklearn cross decomposition model
97+
:type cross_decomposition: obj
98+
:param is_inner_model: determines whether it is an inner linear model of a super ml model
99+
:type is_inner_model: boolean
100+
:return: associated sklearn cross decomposition model
101+
"""
102+
raw_model = None
103+
data = None
104+
if is_inner_model:
105+
raw_model = SKLEARN_CROSS_DECOMPOSITION_TABLE[cross_decomposition["type"]]()
106+
data = cross_decomposition["data"]
107+
else:
108+
raw_model = SKLEARN_CROSS_DECOMPOSITION_TABLE[cross_decomposition.type]()
109+
data = cross_decomposition.data
110+
111+
for transporter in CROSS_DECOMPOSITION_CHAIN:
112+
CROSS_DECOMPOSITION_CHAIN[transporter].transport(
113+
cross_decomposition, Command.DESERIALIZE, is_inner_model)
114+
for item in data:
115+
setattr(raw_model, item, data[item])
116+
return raw_model
117+
118+
119+
def _validate_input(model, command):
120+
"""
121+
Check if the provided inputs are valid in relation to each other.
122+
123+
:param model: a sklearn cross decomposition model or a json string of it, serialized through the pymilo export.
124+
:type model: obj
125+
:param command: command to specify whether the request should be serialized or deserialized
126+
:type command: transporter.Command
127+
:return: None
128+
"""
129+
if command == Command.SERIALIZE:
130+
if is_cross_decomposition(model):
131+
return
132+
else:
133+
raise PymiloSerializationException(
134+
{
135+
'error_type': SerializationErrorTypes.INVALID_MODEL,
136+
'object': model
137+
}
138+
)
139+
elif command == Command.DESERIALIZE:
140+
if is_cross_decomposition(model.type):
141+
return
142+
else:
143+
raise PymiloDeserializationException(
144+
{
145+
'error_type': DeserializationErrorTypes.INVALID_MODEL,
146+
'object': model
147+
}
148+
)

pymilo/chains/util.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from .naive_bayes_chain import transport_naive_bayes, is_naive_bayes
88
from .svm_chain import transport_svm, is_svm
99
from .neighbours_chain import transport_neighbor, is_neighbors
10-
10+
from .cross_decomposition_chain import transport_cross_decomposition, is_cross_decomposition
1111

1212
MODEL_TYPE_TRANSPORTER = {
1313
"LINEAR_MODEL": transport_linear_model,
@@ -16,7 +16,8 @@
1616
"CLUSTERING": transport_clusterer,
1717
"NAIVE_BAYES": transport_naive_bayes,
1818
"SVM": transport_svm,
19-
"NEIGHBORS": transport_neighbor
19+
"NEIGHBORS": transport_neighbor,
20+
"CROSS_DECOMPOSITION": transport_cross_decomposition,
2021
}
2122

2223

@@ -47,5 +48,7 @@ def get_concrete_transporter(model):
4748
return "SVM", transport_svm
4849
elif is_neighbors(model):
4950
return "NEIGHBORS", transport_neighbor
51+
elif is_cross_decomposition(model):
52+
return "CROSS_DECOMPOSITION", transport_cross_decomposition
5053
else:
5154
return None, None

pymilo/pymilo_param.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import sklearn.ensemble as ensemble
1414
import sklearn.pipeline as pipeline
1515
import sklearn.preprocessing as preprocessing
16+
from sklearn.cross_decomposition import PLSRegression, PLSCanonical, CCA
1617

1718
quantile_regressor_support = False
1819
try:
@@ -236,6 +237,12 @@
236237
"TargetEncoder": TargetEncoder if target_encoder_support else NOT_SUPPORTED,
237238
}
238239

240+
SKLEARN_CROSS_DECOMPOSITION_TABLE = {
241+
"PLSRegression": PLSRegression,
242+
"PLSCanonical": PLSCanonical,
243+
"CCA": CCA,
244+
}
245+
239246
KEYS_NEED_PREPROCESSING_BEFORE_DESERIALIZATION = {
240247
"_label_binarizer": preprocessing.LabelBinarizer, # in Ridge Classifier
241248
"active_": np.int32, # in Lasso Lars
@@ -267,4 +274,5 @@
267274
"SVM": "exported_svms",
268275
"NEIGHBORS": "exported_neighbors",
269276
"ENSEMBLE": "exported_ensembles",
277+
"CROSS_DECOMPOSITION": "exported_cross_decomposition",
270278
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from sklearn.cross_decomposition import CCA
2+
from pymilo.utils.test_pymilo import pymilo_regression_test
3+
from pymilo.utils.data_exporter import prepare_simple_regression_datasets
4+
5+
MODEL_NAME = "CCA"
6+
7+
def cca():
8+
x_train, y_train, x_test, y_test = prepare_simple_regression_datasets()
9+
cca = CCA(n_components=1).fit(x_train, y_train)
10+
pymilo_regression_test(cca, MODEL_NAME, (x_test, y_test))
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from sklearn.cross_decomposition import PLSCanonical
2+
from pymilo.utils.test_pymilo import pymilo_regression_test
3+
from pymilo.utils.data_exporter import prepare_simple_regression_datasets
4+
5+
MODEL_NAME = "PLSCanonical"
6+
7+
def pls_canonical():
8+
x_train, y_train, x_test, y_test = prepare_simple_regression_datasets()
9+
pls_canonical = PLSCanonical(n_components=1).fit(x_train, y_train)
10+
pymilo_regression_test(pls_canonical, MODEL_NAME, (x_test, y_test))
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from sklearn.cross_decomposition import PLSRegression
2+
from pymilo.utils.test_pymilo import pymilo_regression_test
3+
from pymilo.utils.data_exporter import prepare_simple_regression_datasets
4+
5+
MODEL_NAME = "PLSRegression"
6+
7+
def pls_regressor():
8+
x_train, y_train, x_test, y_test = prepare_simple_regression_datasets()
9+
pls_regressor = PLSRegression(n_components=2).fit(x_train, y_train)
10+
pymilo_regression_test(pls_regressor, MODEL_NAME, (x_test, y_test))
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import os
2+
import pytest
3+
4+
from pls_regression import pls_regressor
5+
from pls_canonical import pls_canonical
6+
from cca import cca
7+
8+
CROSS_DECOMPOSITIONS = {
9+
"PLS_REGRESSION": [pls_regressor],
10+
"PLS_CANONICAL": [pls_canonical],
11+
"CCA": [cca],
12+
}
13+
14+
@pytest.fixture(scope="session", autouse=True)
15+
def reset_exported_models_directory():
16+
exported_models_directory = os.path.join(
17+
os.getcwd(), "tests", "exported_cross_decomposition")
18+
if not os.path.isdir(exported_models_directory):
19+
os.mkdir(exported_models_directory)
20+
return
21+
for file_name in os.listdir(exported_models_directory):
22+
# construct full file path
23+
json_file = os.path.join(exported_models_directory, file_name)
24+
if os.path.isfile(json_file):
25+
os.remove(json_file)
26+
27+
def test_full():
28+
for category in CROSS_DECOMPOSITIONS:
29+
for model in CROSS_DECOMPOSITIONS[category]:
30+
model()

0 commit comments

Comments
 (0)