Skip to content

Commit bcdc0de

Browse files
authored
Refactor logistic regression (#171)
* replace regression test with classification test * refactor code to use classifcation test same as the other classification models * remove logistic regression specific dataset preparation * `CHANGELOG.md` updated * Update CHANGELOG.md
1 parent 2397ddc commit bcdc0de

File tree

4 files changed

+10
-21
lines changed

4 files changed

+10
-21
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/)
55
and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html).
66

77
## [Unreleased]
8+
### Added
89
### Changed
10+
- testcase for LogisticRegressionCV, LogisticRegression
911
- `README.md` updated
1012
- `AUTHORS.md` updated
1113
## [1.1] - 2024-11-25

pymilo/utils/data_exporter.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -56,16 +56,3 @@ def prepare_simple_clustering_datasets():
5656
X = iris.data # Features
5757
y = iris.target # Target (labels)
5858
return X, y
59-
60-
61-
def prepare_logistic_regression_datasets(threshold=None):
62-
"""
63-
Generate a dataset for logistic regression (the iris).
64-
65-
:param threshold: threshold for train/test splitting
66-
:int threshold: int
67-
:return: splitted dataset for logistic regression
68-
"""
69-
iris_X, iris_y = datasets.load_iris(return_X_y=True)
70-
threshold = threshold if threshold else len(iris_y) // 2
71-
return _split_X_y(iris_X, iris_y, threshold)
Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
11
from sklearn.linear_model import LogisticRegression
2-
from pymilo.utils.test_pymilo import pymilo_regression_test
3-
from pymilo.utils.data_exporter import prepare_simple_regression_datasets
2+
from pymilo.utils.test_pymilo import pymilo_classification_test
3+
from pymilo.utils.data_exporter import prepare_simple_classification_datasets
44

55
MODEL_NAME = "Logistic-Regression"
66

77

88
def logistic_regression():
9-
x_train, y_train, x_test, y_test = prepare_simple_regression_datasets()
9+
x_train, y_train, x_test, y_test = prepare_simple_classification_datasets()
1010
# Create Logistic regression object
1111
logistic_regression_random_state = 4
1212
logistic_regression = LogisticRegression(
1313
random_state=logistic_regression_random_state)
1414
# Train the model using the training sets
1515
logistic_regression.fit(x_train, y_train)
16-
assert pymilo_regression_test(
16+
assert pymilo_classification_test(
1717
logistic_regression, MODEL_NAME, (x_test, y_test)) == True
Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
from sklearn.linear_model import LogisticRegressionCV
2-
from pymilo.utils.test_pymilo import pymilo_regression_test
3-
from pymilo.utils.data_exporter import prepare_logistic_regression_datasets
2+
from pymilo.utils.test_pymilo import pymilo_classification_test
3+
from pymilo.utils.data_exporter import prepare_simple_classification_datasets
44

55
MODEL_NAME = "Logistic-Regression-CV"
66

77

88
def logistic_regression_cv():
9-
x_train, y_train, x_test, y_test = prepare_logistic_regression_datasets()
9+
x_train, y_train, x_test, y_test = prepare_simple_classification_datasets()
1010
# Create Logistic regression cv object
1111
logistic_regression_cv = 5
1212
logistic_regression_random_state = 0
@@ -15,5 +15,5 @@ def logistic_regression_cv():
1515
random_state=logistic_regression_random_state)
1616
# Train the model using the training sets
1717
logistic_regression_cv.fit(x_train, y_train)
18-
assert pymilo_regression_test(
18+
assert pymilo_classification_test(
1919
logistic_regression_cv, MODEL_NAME, (x_test, y_test)) == True

0 commit comments

Comments
 (0)