Skip to content

Commit 2ca9e9e

Browse files
committed
try again
1 parent 5ab21a9 commit 2ca9e9e

File tree

1 file changed

+52
-35
lines changed

1 file changed

+52
-35
lines changed

tests/system/small/ml/test_forecasting.py

Lines changed: 52 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -451,26 +451,35 @@ def test_arima_plus_score(
451451
id_col_name,
452452
):
453453
if id_col_name:
454-
result = time_series_arima_plus_model_w_id.score(
455-
new_time_series_df_w_id[["parsed_date"]],
456-
new_time_series_df_w_id[["total_visits"]],
457-
new_time_series_df_w_id[["id"]],
458-
).to_pandas()
454+
result = (
455+
time_series_arima_plus_model_w_id.score(
456+
new_time_series_df_w_id[["parsed_date"]],
457+
new_time_series_df_w_id[["total_visits"]],
458+
new_time_series_df_w_id[["id"]],
459+
)
460+
.to_pandas()
461+
.sort_values("id")
462+
.reset_index()
463+
)
459464
else:
460465
result = time_series_arima_plus_model.score(
461466
new_time_series_df[["parsed_date"]], new_time_series_df[["total_visits"]]
462467
).to_pandas()
463468
if id_col_name:
464-
expected = pd.DataFrame(
465-
{
466-
"id": ["2", "1"],
467-
"mean_absolute_error": [120.011007, 120.011007],
468-
"mean_squared_error": [14562.562359, 14562.562359],
469-
"root_mean_squared_error": [120.675442, 120.675442],
470-
"mean_absolute_percentage_error": [4.80044, 4.80044],
471-
"symmetric_mean_absolute_percentage_error": [4.744332, 4.744332],
472-
},
473-
dtype="Float64",
469+
expected = (
470+
pd.DataFrame(
471+
{
472+
"id": ["2", "1"],
473+
"mean_absolute_error": [120.011007, 120.011007],
474+
"mean_squared_error": [14562.562359, 14562.562359],
475+
"root_mean_squared_error": [120.675442, 120.675442],
476+
"mean_absolute_percentage_error": [4.80044, 4.80044],
477+
"symmetric_mean_absolute_percentage_error": [4.744332, 4.744332],
478+
},
479+
dtype="Float64",
480+
)
481+
.sort_values("id")
482+
.reset_index()
474483
)
475484
expected["id"] = expected["id"].astype(str).str.replace(r"\.0$", "", regex=True)
476485
expected["id"] = expected["id"].astype("string[pyarrow]")
@@ -486,8 +495,8 @@ def test_arima_plus_score(
486495
dtype="Float64",
487496
)
488497
pd.testing.assert_frame_equal(
489-
result.sort_values("id").reset_index(),
490-
expected.sort_values("id").reset_index(),
498+
result,
499+
expected,
491500
rtol=0.1,
492501
check_index_type=False,
493502
check_dtype=False,
@@ -545,26 +554,35 @@ def test_arima_plus_score_series(
545554
id_col_name,
546555
):
547556
if id_col_name:
548-
result = time_series_arima_plus_model_w_id.score(
549-
new_time_series_df_w_id["parsed_date"],
550-
new_time_series_df_w_id["total_visits"],
551-
new_time_series_df_w_id["id"],
552-
).to_pandas()
557+
result = (
558+
time_series_arima_plus_model_w_id.score(
559+
new_time_series_df_w_id["parsed_date"],
560+
new_time_series_df_w_id["total_visits"],
561+
new_time_series_df_w_id["id"],
562+
)
563+
.to_pandas()
564+
.sort_values("id")
565+
.reset_index()
566+
)
553567
else:
554568
result = time_series_arima_plus_model.score(
555569
new_time_series_df["parsed_date"], new_time_series_df["total_visits"]
556570
).to_pandas()
557571
if id_col_name:
558-
expected = pd.DataFrame(
559-
{
560-
"id": ["2", "1"],
561-
"mean_absolute_error": [120.011007, 120.011007],
562-
"mean_squared_error": [14562.562359, 14562.562359],
563-
"root_mean_squared_error": [120.675442, 120.675442],
564-
"mean_absolute_percentage_error": [4.80044, 4.80044],
565-
"symmetric_mean_absolute_percentage_error": [4.744332, 4.744332],
566-
},
567-
dtype="Float64",
572+
expected = (
573+
pd.DataFrame(
574+
{
575+
"id": ["2", "1"],
576+
"mean_absolute_error": [120.011007, 120.011007],
577+
"mean_squared_error": [14562.562359, 14562.562359],
578+
"root_mean_squared_error": [120.675442, 120.675442],
579+
"mean_absolute_percentage_error": [4.80044, 4.80044],
580+
"symmetric_mean_absolute_percentage_error": [4.744332, 4.744332],
581+
},
582+
dtype="Float64",
583+
)
584+
.sort_values("id")
585+
.reset_index()
568586
)
569587
expected["id"] = expected["id"].astype(str).str.replace(r"\.0$", "", regex=True)
570588
expected["id"] = expected["id"].astype("string[pyarrow]")
@@ -580,11 +598,10 @@ def test_arima_plus_score_series(
580598
dtype="Float64",
581599
)
582600
pd.testing.assert_frame_equal(
583-
result.sort_values("id").reset_index(),
584-
expected.sort_values("id").reset_index(),
601+
result,
602+
expected,
585603
rtol=0.1,
586604
check_index_type=False,
587-
check_dtype=False,
588605
)
589606

590607

0 commit comments

Comments
 (0)