Skip to content

Commit b6aab8e

Browse files
author
Nabil Fayak
committed
lint fix
1 parent 49aa7c3 commit b6aab8e

File tree

5 files changed

+19
-7
lines changed

5 files changed

+19
-7
lines changed

checkmates/objectives/__init__.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,9 @@
1212
from checkmates.objectives.standard_metrics import RootMeanSquaredLogError
1313
from checkmates.objectives.standard_metrics import MeanSquaredLogError
1414

15-
from checkmates.objectives.binary_classification_objective import BinaryClassificationObjective
16-
from checkmates.objectives.multiclass_classification_objective import MulticlassClassificationObjective
15+
from checkmates.objectives.binary_classification_objective import (
16+
BinaryClassificationObjective,
17+
)
18+
from checkmates.objectives.multiclass_classification_objective import (
19+
MulticlassClassificationObjective,
20+
)

checkmates/objectives/binary_classification_objective.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,4 +81,4 @@ def validate_inputs(self, y_true, y_predicted):
8181
if len(np.unique(y_true)) > 2:
8282
raise ValueError("y_true contains more than two unique values")
8383
if len(np.unique(y_predicted)) > 2 and not self.score_needs_proba:
84-
raise ValueError("y_predicted contains more than two unique values")
84+
raise ValueError("y_predicted contains more than two unique values")

checkmates/objectives/multiclass_classification_objective.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,4 @@ class MulticlassClassificationObjective(ObjectiveBase):
77
"""Base class for all multiclass classification objectives."""
88

99
problem_types = [ProblemTypes.MULTICLASS, ProblemTypes.TIME_SERIES_MULTICLASS]
10-
"""[ProblemTypes.MULTICLASS, ProblemTypes.TIME_SERIES_MULTICLASS]"""
10+
"""[ProblemTypes.MULTICLASS, ProblemTypes.TIME_SERIES_MULTICLASS]"""

checkmates/objectives/standard_metrics.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,14 @@
33
import pandas as pd
44
from sklearn import metrics
55

6+
from checkmates.objectives.binary_classification_objective import (
7+
BinaryClassificationObjective,
8+
)
9+
from checkmates.objectives.multiclass_classification_objective import (
10+
MulticlassClassificationObjective,
11+
)
612
from checkmates.objectives.regression_objective import RegressionObjective
713
from checkmates.utils import classproperty
8-
from checkmates.objectives.binary_classification_objective import BinaryClassificationObjective
9-
from checkmates.objectives.multiclass_classification_objective import MulticlassClassificationObjective
1014

1115

1216
class LogLossBinary(BinaryClassificationObjective):
@@ -36,6 +40,7 @@ def objective_function(
3640
"""Objective function for log loss for binary classification."""
3741
return metrics.log_loss(y_true, y_predicted, sample_weight=sample_weight)
3842

43+
3944
class LogLossMulticlass(MulticlassClassificationObjective):
4045
"""Log Loss for multiclass classification.
4146
@@ -68,6 +73,7 @@ def objective_function(
6873
"""Objective function for log loss for multiclass classification."""
6974
return metrics.log_loss(y_true, y_predicted, sample_weight=sample_weight)
7075

76+
7177
class R2(RegressionObjective):
7278
"""Coefficient of determination for regression.
7379
@@ -95,6 +101,7 @@ def objective_function(
95101
"""Objective function for coefficient of determination for regression."""
96102
return metrics.r2_score(y_true, y_predicted, sample_weight=sample_weight)
97103

104+
98105
class MedianAE(RegressionObjective):
99106
"""Median absolute error for regression.
100107
@@ -127,7 +134,6 @@ def objective_function(
127134
)
128135

129136

130-
131137
class RootMeanSquaredLogError(RegressionObjective):
132138
"""Root mean squared log error for regression.
133139

checkmates/objectives/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def get_non_core_objectives():
2020
objectives.RootMeanSquaredLogError,
2121
]
2222

23+
2324
def get_all_objective_names():
2425
"""Get a list of the names of all objectives.
2526
@@ -29,6 +30,7 @@ def get_all_objective_names():
2930
all_objectives_dict = _all_objectives_dict()
3031
return list(all_objectives_dict.keys())
3132

33+
3234
def _all_objectives_dict():
3335
all_objectives = _get_subclasses(ObjectiveBase)
3436
objectives_dict = {}

0 commit comments

Comments
 (0)