Skip to content

Commit 908367c

Browse files
janrthnasaul
andauthored
Handling of 0 dividers for nd loss (#188)
Co-authored-by: Saul <[email protected]>
1 parent 11703c4 commit 908367c

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

utilsforecast/losses.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -405,7 +405,7 @@ def nd(
405405
df: IntoDataFrameT,
406406
models: List[str],
407407
id_col: str = "unique_id",
408-
target_col: str = "y",
408+
target_col: str = "y"
409409
) -> IntoDataFrameT:
410410
"""Normalized Deviation (ND)
411411
@@ -427,7 +427,7 @@ def nd(
427427

428428
def gen_expr(model):
429429
return ((nw.col(target_col) - nw.col(model)).abs()).alias(model)
430-
430+
431431
return (
432432
nw.from_native(df)
433433
.select(
@@ -437,7 +437,7 @@ def gen_expr(model):
437437
)
438438
.group_by(id_col)
439439
.agg(nw.all().sum())
440-
.select(id_col, *[nw.col(m) / nw.col("scale") for m in models])
440+
.select(id_col, *[(nw.col(m) / _zero_to_nan(nw.col("scale"))) for m in models])
441441
.sort(id_col)
442442
.to_native()
443443
)

0 commit comments

Comments
 (0)