Skip to content

Commit 6d3cea7

Browse files
committed
Update: RMSLE metric
1 parent e5fe4f7 commit 6d3cea7

File tree

1 file changed

+24
-1
lines changed

1 file changed

+24
-1
lines changed

wqf/val/metrics.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def rer(
5858
) -> DataArray:
5959
"""Returns the relative error."""
6060
ref, pre = _select(ref, pre, condition)
61-
return Bias.err(pre, ref) / ref
61+
return Bias.err(ref, pre) / ref
6262

6363

6464
class Count(Metric):
@@ -212,6 +212,29 @@ def se(ref: DataArray, pre: DataArray) -> DataArray:
212212
return da.square(ref - pre.data)
213213

214214

215+
class RMSLE(Metric):
216+
"""
217+
The root mean squared logarithmic error (RMSLE).
218+
219+
The RMSLE asymptotically approaches the RMSE for values smaller
220+
than unity and the MAPD for values larger than unity.
221+
"""
222+
223+
def value(self, ref: DataArray, pre: DataArray, **kwargs) -> Number:
224+
return da.sqrt(RMSLE.sle(ref, pre).mean()).values.item()
225+
226+
def image(self, ref: DataArray, pre: DataArray, **kwargs) -> DataArray:
227+
return da.sqrt(RMSLE.sle(ref, pre).mean(DID_TIM))
228+
229+
def series(self, ref: DataArray, pre: DataArray, **kwargs) -> DataArray:
230+
return da.sqrt(RMSLE.sle(ref, pre).mean([DID_LAT, DID_LON]))
231+
232+
@staticmethod
233+
def sle(ref: DataArray, pre: DataArray) -> DataArray:
234+
"""Returns the squared logarithmic error."""
235+
return da.square(da.log1p(ref) - da.log1p(pre.data))
236+
237+
215238
class WRMSSE(Metric):
216239
"""
217240
The weighted root mean squared scaled error (WRMSSE).

0 commit comments

Comments
 (0)