Skip to content

Commit 3fac924

Browse files
authored
[FIX] Fix spis (#190)
1 parent 52a7f60 commit 3fac924

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

tests/test_losses.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ def test_spis_numerical(self):
178178
Validates scaled sum of absolute errors per series (Pandas and Polars).
179179
"""
180180
# In-sample data for scaling (mean per series)
181-
df_train = pd.DataFrame({
181+
train_df = pd.DataFrame({
182182
"unique_id": ["A", "A", "B", "B"],
183183
"y": [1, 3, 2, 6]
184184
})
@@ -200,7 +200,7 @@ def test_spis_numerical(self):
200200
# Pandas test
201201
out_pd = spis(
202202
df=df,
203-
df_train=df_train,
203+
train_df=train_df,
204204
models=["y_hat"],
205205
id_col="unique_id",
206206
target_col="y",
@@ -210,7 +210,7 @@ def test_spis_numerical(self):
210210
# Polars test
211211
out_pl = spis(
212212
df=pl.DataFrame(df),
213-
df_train=pl.DataFrame(df_train),
213+
train_df=pl.DataFrame(train_df),
214214
models=["y_hat"],
215215
id_col="unique_id",
216216
target_col="y",

utilsforecast/losses.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ def pis(
292292
@_base_docstring
293293
def spis(
294294
df: DFType,
295-
df_train: DFType,
295+
train_df: DFType,
296296
models: List[str],
297297
id_col: str = "unique_id",
298298
target_col: str = "y",
@@ -304,7 +304,7 @@ def spis(
304304
yielding a scale-independent bias measure that can be aggregated across series.
305305
"""
306306
if isinstance(df, pd.DataFrame):
307-
ins_means = df_train.groupby(id_col)[target_col].mean().rename("insample_mean")
307+
ins_means = train_df.groupby(id_col)[target_col].mean().rename("insample_mean")
308308
abs_err_sum = (
309309
(df[models].sub(df[target_col], axis=0))
310310
.abs()
@@ -315,7 +315,7 @@ def spis(
315315
res.index.name = id_col
316316
return res.reset_index()
317317
else:
318-
ins_means = df_train.group_by(id_col).agg(
318+
ins_means = train_df.group_by(id_col).agg(
319319
pl.col(target_col).mean().alias("insample_mean")
320320
)
321321
abs_err = _pl_agg_expr(

0 commit comments

Comments
 (0)