|
8 | 8 | from numbers import Number |
9 | 9 |
|
10 | 10 | import dask.array as da |
| 11 | +from dask.array import Array |
11 | 12 | import numpy as np |
12 | 13 | from xarray import DataArray |
13 | 14 |
|
@@ -293,25 +294,32 @@ def rmsse( |
293 | 294 | fwd = WRMSSE._fwd_mean_squared_diff(h, ref, pre) |
294 | 295 | bwd = WRMSSE._bwd_mean_squared_diff(b, ref) |
295 | 296 | return da.sqrt( |
296 | | - fwd[b + 1 + h // 2 : fwd.shape[0] - h // 2] |
| 297 | + fwd[b + h // 2 : fwd.shape[0] - h // 2] |
297 | 298 | / bwd[b // 2 : bwd.shape[0] - b // 2 - h] |
298 | 299 | ) |
299 | 300 |
|
300 | 301 | @staticmethod |
301 | 302 | def _bwd_mean_squared_diff(b: int, ref: DataArray) -> DataArray: |
| 303 | + def ker(x: Array, **kwargs) -> DataArray: |
| 304 | + """Returns the mean squared difference.""" |
| 305 | + return da.mean(da.square(da.diff(x)), axis=-1) |
| 306 | + |
302 | 307 | return ( |
303 | | - da.square(ref.diff(DID_TIM)) |
304 | | - .rolling({DID_TIM: b}, min_periods=b) |
305 | | - .mean() |
| 308 | + ref.rolling({DID_TIM: b}, min_periods=b, center=True) |
| 309 | + .reduce(ker) |
306 | 310 | .drop_vars(DID_TIM) |
307 | 311 | ) |
308 | 312 |
|
309 | 313 | @staticmethod |
310 | 314 | def _fwd_mean_squared_diff( |
311 | 315 | h: int, ref: DataArray, pre: DataArray |
312 | 316 | ) -> DataArray: |
| 317 | + def ker(x: Array, **kwargs) -> DataArray: |
| 318 | + """Returns the mean value.""" |
| 319 | + return da.mean(x, axis=-1) |
| 320 | + |
313 | 321 | return ( |
314 | 322 | da.square(ref - pre.data) |
315 | | - .rolling({DID_TIM: h}, min_periods=h) |
316 | | - .mean() |
| 323 | + .rolling({DID_TIM: h}, min_periods=h, center=True) |
| 324 | + .reduce(ker) |
317 | 325 | ) |
0 commit comments