@@ -79,6 +79,7 @@ def __init__(
7979 inner_k_folds : int = - 1 , # use inner cross-validation if > 1
8080 fold_index : Optional [int ] = None ,
8181 base_dir : Optional [str ] = None ,
82+ n_token_limit : Optional [int ] = None ,
8283 ** kwargs ,
8384 ):
8485 super ().__init__ ()
@@ -110,6 +111,7 @@ def __init__(
110111 ), "fold_index can't be larger than the total number of folds"
111112 self .fold_index = fold_index
112113 self ._base_dir = base_dir
114+ self .n_token_limit = n_token_limit
113115 os .makedirs (self .raw_dir , exist_ok = True )
114116 os .makedirs (self .processed_dir , exist_ok = True )
115117 if self .use_inner_cross_validation :
@@ -311,8 +313,9 @@ def _load_data_from_file(self, path: str) -> List[Dict[str, Any]]:
311313 for d in tqdm .tqdm (self ._load_dict (path ), total = lines )
312314 if d ["features" ] is not None
313315 ]
314- # filter for missing features in resulting data
315- data = [val for val in data if val ["features" ] is not None ]
316+ # drop entries without features; if n_token_limit is set, also drop longer ones (parenthesised: `and` binds tighter than `or`, so the None check must guard len())
317+ data = [val for val in data if val["features"] is not None
318+ and (self.n_token_limit is None or len(val["features"]) <= self.n_token_limit)]
316319
317320 return data
318321
@@ -1181,4 +1184,6 @@ def processed_file_names_dict(self) -> dict:
11811184 dict: A dictionary mapping dataset keys to their respective file names.
11821185 For example, {"data": "data.pt"}.
11831186 """
1187+ if self .n_token_limit is not None :
1188+ return {"data" : f"data_maxlen{ self .n_token_limit } .pt" }
11841189 return {"data" : "data.pt" }
0 commit comments