@@ -451,26 +451,35 @@ def test_arima_plus_score(
451
451
id_col_name ,
452
452
):
453
453
if id_col_name :
454
- result = time_series_arima_plus_model_w_id .score (
455
- new_time_series_df_w_id [["parsed_date" ]],
456
- new_time_series_df_w_id [["total_visits" ]],
457
- new_time_series_df_w_id [["id" ]],
458
- ).to_pandas ()
454
+ result = (
455
+ time_series_arima_plus_model_w_id .score (
456
+ new_time_series_df_w_id [["parsed_date" ]],
457
+ new_time_series_df_w_id [["total_visits" ]],
458
+ new_time_series_df_w_id [["id" ]],
459
+ )
460
+ .to_pandas ()
461
+ .sort_values ("id" )
462
+ .reset_index ()
463
+ )
459
464
else :
460
465
result = time_series_arima_plus_model .score (
461
466
new_time_series_df [["parsed_date" ]], new_time_series_df [["total_visits" ]]
462
467
).to_pandas ()
463
468
if id_col_name :
464
- expected = pd .DataFrame (
465
- {
466
- "id" : ["2" , "1" ],
467
- "mean_absolute_error" : [120.011007 , 120.011007 ],
468
- "mean_squared_error" : [14562.562359 , 14562.562359 ],
469
- "root_mean_squared_error" : [120.675442 , 120.675442 ],
470
- "mean_absolute_percentage_error" : [4.80044 , 4.80044 ],
471
- "symmetric_mean_absolute_percentage_error" : [4.744332 , 4.744332 ],
472
- },
473
- dtype = "Float64" ,
469
+ expected = (
470
+ pd .DataFrame (
471
+ {
472
+ "id" : ["2" , "1" ],
473
+ "mean_absolute_error" : [120.011007 , 120.011007 ],
474
+ "mean_squared_error" : [14562.562359 , 14562.562359 ],
475
+ "root_mean_squared_error" : [120.675442 , 120.675442 ],
476
+ "mean_absolute_percentage_error" : [4.80044 , 4.80044 ],
477
+ "symmetric_mean_absolute_percentage_error" : [4.744332 , 4.744332 ],
478
+ },
479
+ dtype = "Float64" ,
480
+ )
481
+ .sort_values ("id" )
482
+ .reset_index ()
474
483
)
475
484
expected ["id" ] = expected ["id" ].astype (str ).str .replace (r"\.0$" , "" , regex = True )
476
485
expected ["id" ] = expected ["id" ].astype ("string[pyarrow]" )
@@ -486,8 +495,8 @@ def test_arima_plus_score(
486
495
dtype = "Float64" ,
487
496
)
488
497
pd .testing .assert_frame_equal (
489
- result . sort_values ( "id" ). reset_index () ,
490
- expected . sort_values ( "id" ). reset_index () ,
498
+ result ,
499
+ expected ,
491
500
rtol = 0.1 ,
492
501
check_index_type = False ,
493
502
check_dtype = False ,
@@ -545,26 +554,35 @@ def test_arima_plus_score_series(
545
554
id_col_name ,
546
555
):
547
556
if id_col_name :
548
- result = time_series_arima_plus_model_w_id .score (
549
- new_time_series_df_w_id ["parsed_date" ],
550
- new_time_series_df_w_id ["total_visits" ],
551
- new_time_series_df_w_id ["id" ],
552
- ).to_pandas ()
557
+ result = (
558
+ time_series_arima_plus_model_w_id .score (
559
+ new_time_series_df_w_id ["parsed_date" ],
560
+ new_time_series_df_w_id ["total_visits" ],
561
+ new_time_series_df_w_id ["id" ],
562
+ )
563
+ .to_pandas ()
564
+ .sort_values ("id" )
565
+ .reset_index ()
566
+ )
553
567
else :
554
568
result = time_series_arima_plus_model .score (
555
569
new_time_series_df ["parsed_date" ], new_time_series_df ["total_visits" ]
556
570
).to_pandas ()
557
571
if id_col_name :
558
- expected = pd .DataFrame (
559
- {
560
- "id" : ["2" , "1" ],
561
- "mean_absolute_error" : [120.011007 , 120.011007 ],
562
- "mean_squared_error" : [14562.562359 , 14562.562359 ],
563
- "root_mean_squared_error" : [120.675442 , 120.675442 ],
564
- "mean_absolute_percentage_error" : [4.80044 , 4.80044 ],
565
- "symmetric_mean_absolute_percentage_error" : [4.744332 , 4.744332 ],
566
- },
567
- dtype = "Float64" ,
572
+ expected = (
573
+ pd .DataFrame (
574
+ {
575
+ "id" : ["2" , "1" ],
576
+ "mean_absolute_error" : [120.011007 , 120.011007 ],
577
+ "mean_squared_error" : [14562.562359 , 14562.562359 ],
578
+ "root_mean_squared_error" : [120.675442 , 120.675442 ],
579
+ "mean_absolute_percentage_error" : [4.80044 , 4.80044 ],
580
+ "symmetric_mean_absolute_percentage_error" : [4.744332 , 4.744332 ],
581
+ },
582
+ dtype = "Float64" ,
583
+ )
584
+ .sort_values ("id" )
585
+ .reset_index ()
568
586
)
569
587
expected ["id" ] = expected ["id" ].astype (str ).str .replace (r"\.0$" , "" , regex = True )
570
588
expected ["id" ] = expected ["id" ].astype ("string[pyarrow]" )
@@ -580,11 +598,10 @@ def test_arima_plus_score_series(
580
598
dtype = "Float64" ,
581
599
)
582
600
pd .testing .assert_frame_equal (
583
- result . sort_values ( "id" ). reset_index () ,
584
- expected . sort_values ( "id" ). reset_index () ,
601
+ result ,
602
+ expected ,
585
603
rtol = 0.1 ,
586
604
check_index_type = False ,
587
- check_dtype = False ,
588
605
)
589
606
590
607
0 commit comments