@@ -58,14 +58,25 @@ def splitter(
5858 train , dev , test = simple_split (data , pct_train , pct_dev , pct_test )
5959
6060 # Final assertions for time series
61- window = tss .get ('window' , 1 ) if tss .get ('window' , 1 ) else 1
62- horizon = tss .get ('horizon' , 1 ) if tss .get ('horizon' , 1 ) else 1
63-
64- if min (len (train ), len (dev )) < window :
65- raise Exception (f"Dataset size is too small for the specified window size ({ window } )" )
66-
67- if min (len (train ), len (dev ), len (test )) < horizon :
68- raise Exception (f"Dataset size is too small for the specified horizon size ({ horizon } )" )
61+ if tss .get ('is_timeseries' , False ) not in (None , False ):
62+ window = tss .get ('window' , 1 ) if tss .get ('window' , 1 ) else 1
63+ horizon = tss .get ('horizon' , 1 ) if tss .get ('horizon' , 1 ) else 1
64+
65+ if all ([pct_train , pct_dev , pct_test ]) > 0.0 :
66+ check_partitions = [train , dev , test ]
67+ elif all ([pct_train , pct_test ]) > 0.0 :
68+ check_partitions = [train , test ]
69+ elif all ([pct_train , pct_dev ]) > 0.0 :
70+ check_partitions = [train , dev ]
71+ else :
72+ check_partitions = [train ]
73+ partition_lengths = [len (partition ) for partition in check_partitions ]
74+
75+ if min (partition_lengths ) < window :
76+ raise Exception (f"Dataset too small for the specified window size ({ window } ). Partition length: { partition_lengths } " ) # noqa
77+
78+ if min (partition_lengths ) < horizon :
79+ raise Exception (f"Dataset too small for the specified horizon size ({ horizon } ). Partition length: { partition_lengths } " ) # noqa
6980
7081 return {"train" : train , "test" : test , "dev" : dev , "stratified_on" : stratify_on }
7182
0 commit comments