Skip to content

Commit 095b445

Browse files
committed
Fix: WRMSSE
1 parent 96160ed commit 095b445

File tree

1 file changed

+14
-6
lines changed

1 file changed

+14
-6
lines changed

wqf/val/metrics.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from numbers import Number
99

1010
import dask.array as da
11+
from dask.array import Array
1112
import numpy as np
1213
from xarray import DataArray
1314

@@ -293,25 +294,32 @@ def rmsse(
293294
fwd = WRMSSE._fwd_mean_squared_diff(h, ref, pre)
294295
bwd = WRMSSE._bwd_mean_squared_diff(b, ref)
295296
return da.sqrt(
296-
fwd[b + 1 + h // 2 : fwd.shape[0] - h // 2]
297+
fwd[b + h // 2 : fwd.shape[0] - h // 2]
297298
/ bwd[b // 2 : bwd.shape[0] - b // 2 - h]
298299
)
299300

300301
@staticmethod
301302
def _bwd_mean_squared_diff(b: int, ref: DataArray) -> DataArray:
303+
def ker(x: Array, **kwargs) -> DataArray:
304+
"""Returns the mean squared difference."""
305+
return da.mean(da.square(da.diff(x)), axis=-1)
306+
302307
return (
303-
da.square(ref.diff(DID_TIM))
304-
.rolling({DID_TIM: b}, min_periods=b)
305-
.mean()
308+
ref.rolling({DID_TIM: b}, min_periods=b, center=True)
309+
.reduce(ker)
306310
.drop_vars(DID_TIM)
307311
)
308312

309313
@staticmethod
310314
def _fwd_mean_squared_diff(
311315
h: int, ref: DataArray, pre: DataArray
312316
) -> DataArray:
317+
def ker(x: Array, **kwargs) -> DataArray:
318+
"""Returns the mean value."""
319+
return da.mean(x, axis=-1)
320+
313321
return (
314322
da.square(ref - pre.data)
315-
.rolling({DID_TIM: h}, min_periods=h)
316-
.mean()
323+
.rolling({DID_TIM: h}, min_periods=h, center=True)
324+
.reduce(ker)
317325
)

0 commit comments

Comments
 (0)