Skip to content

Commit c64b217

Browse files
authored
ENH: add mean_squared_log_error (#725)
* ENH: add mean_squared_log_error
1 parent ae2b57a commit c64b217

File tree

4 files changed

+38
-1
lines changed

4 files changed

+38
-1
lines changed

dask_ml/metrics/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,10 @@
44
pairwise_distances,
55
pairwise_distances_argmin_min,
66
)
7-
from .regression import mean_absolute_error, mean_squared_error, r2_score # noqa
7+
from .regression import ( # noqa
8+
mean_absolute_error,
9+
mean_squared_error,
10+
mean_squared_log_error,
11+
r2_score,
12+
)
813
from .scorer import SCORERS, check_scoring, get_scorer # noqa

dask_ml/metrics/regression.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,3 +110,22 @@ def r2_score(
110110
if compute:
111111
result = result.compute()
112112
return result
113+
114+
115+
@derived_from(sklearn.metrics)
116+
def mean_squared_log_error(
117+
y_true: ArrayLike,
118+
y_pred: ArrayLike,
119+
sample_weight: Optional[ArrayLike] = None,
120+
multioutput: Optional[str] = "uniform_average",
121+
compute: bool = True,
122+
) -> ArrayLike:
123+
124+
result = mean_squared_error(
125+
np.log1p(y_true),
126+
np.log1p(y_pred),
127+
sample_weight=sample_weight,
128+
multioutput=multioutput,
129+
compute=compute,
130+
)
131+
return result

docs/source/modules/api.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,7 @@ Regression Metrics
231231

232232
metrics.mean_absolute_error
233233
metrics.mean_squared_error
234+
metrics.mean_squared_log_error
234235
metrics.r2_score
235236

236237

tests/metrics/test_regression.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,15 @@ def test_mse_squared(squared):
4848
result = m1(a, b, squared=squared)
4949
expected = m2(a, b, squared=squared)
5050
assert abs(result - expected) < 1e-5
51+
52+
53+
def test_mean_squared_log_error():
54+
m1 = dask_ml.metrics.mean_squared_log_error
55+
m2 = sklearn.metrics.mean_squared_log_error
56+
57+
a = da.random.uniform(size=(100,), chunks=(25,))
58+
b = da.random.uniform(size=(100,), chunks=(25,))
59+
60+
result = m1(a, b)
61+
expected = m2(a, b)
62+
assert abs(result - expected) < 1e-5

0 commit comments

Comments
 (0)