Commit 5870948

add a parameter to the dataset that (if set) throws out all instances that have more than x tokens
1 parent 052677e commit 5870948

File tree

  • chebai/preprocessing/datasets
      • base.py

1 file changed: +7 −2 lines changed


chebai/preprocessing/datasets/base.py

Lines changed: 7 additions & 2 deletions
@@ -79,6 +79,7 @@ def __init__(
         inner_k_folds: int = -1,  # use inner cross-validation if > 1
         fold_index: Optional[int] = None,
         base_dir: Optional[str] = None,
+        n_token_limit: Optional[int] = None,
         **kwargs,
     ):
         super().__init__()
@@ -110,6 +111,7 @@ def __init__(
         ), "fold_index can't be larger than the total number of folds"
         self.fold_index = fold_index
         self._base_dir = base_dir
+        self.n_token_limit = n_token_limit
         os.makedirs(self.raw_dir, exist_ok=True)
         os.makedirs(self.processed_dir, exist_ok=True)
         if self.use_inner_cross_validation:
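
Taken together, the two hunks above thread the new keyword through the constructor and store it on the instance before the raw and processed directories are created. A minimal stand-in illustrating the pattern (the diff does not show the base class's name, so BaseDataset here is a placeholder, not chebai's actual class):

    # stand-in sketch of the constructor change; the real class in
    # chebai/preprocessing/datasets/base.py takes many more parameters
    from typing import Optional

    class BaseDataset:  # placeholder name
        def __init__(
            self,
            fold_index: Optional[int] = None,
            base_dir: Optional[str] = None,
            n_token_limit: Optional[int] = None,  # new in this commit
            **kwargs,
        ):
            self.fold_index = fold_index
            self._base_dir = base_dir
            self.n_token_limit = n_token_limit  # read later while loading data

    ds = BaseDataset(n_token_limit=512)
    print(ds.n_token_limit)  # 512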
@@ -311,8 +313,9 @@ def _load_data_from_file(self, path: str) -> List[Dict[str, Any]]:
             for d in tqdm.tqdm(self._load_dict(path), total=lines)
             if d["features"] is not None
         ]
-        # filter for missing features in resulting data
-        data = [val for val in data if val["features"] is not None]
+        # filter for missing features; keep feature length within the token limit
+        data = [val for val in data if val["features"] is not None
+                and (self.n_token_limit is None or len(val["features"]) <= self.n_token_limit)]

         return data
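
A note on the filter predicate (parenthesized above): and binds more tightly than or in Python, so the unparenthesized form val["features"] is not None and self.n_token_limit is None or len(val["features"]) <= self.n_token_limit parses as (A and B) or C and evaluates len(val["features"]) even when features is None, once a limit is set. Entries with missing features are already dropped by the comprehension above, but the corrected grouping makes the predicate safe on its own. A standalone sketch of the fixed filter (not chebai code):

    # standalone sketch of the corrected filter predicate
    from typing import Any, Dict, List, Optional

    def filter_by_token_limit(
        data: List[Dict[str, Any]], n_token_limit: Optional[int]
    ) -> List[Dict[str, Any]]:
        # keep entries whose features exist and, if a limit is set, fit within it
        return [
            val
            for val in data
            if val["features"] is not None
            and (n_token_limit is None or len(val["features"]) <= n_token_limit)
        ]

    rows = [
        {"features": [1, 2, 3]},
        {"features": list(range(600))},
        {"features": None},
    ]
    assert len(filter_by_token_limit(rows, 512)) == 1   # long and missing entries dropped
    assert len(filter_by_token_limit(rows, None)) == 2  # only the missing entry dropped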

@@ -1181,4 +1184,6 @@ def processed_file_names_dict(self) -> dict:
             dict: A dictionary mapping dataset keys to their respective file names.
             For example, {"data": "data.pt"}.
         """
+        if self.n_token_limit is not None:
+            return {"data": f"data_maxlen{self.n_token_limit}.pt"}
         return {"data": "data.pt"}
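
Encoding the limit in the processed file name keeps caches for different limits separate: a run with n_token_limit=512 writes data_maxlen512.pt instead of overwriting the unfiltered data.pt, so changing the limit between runs cannot serve stale, differently filtered data. A standalone sketch of the naming rule (mirroring the property above):

    # mirrors processed_file_names_dict's naming rule, outside the class
    from typing import Optional

    def processed_file_name(n_token_limit: Optional[int]) -> str:
        if n_token_limit is not None:
            return f"data_maxlen{n_token_limit}.pt"
        return "data.pt"

    print(processed_file_name(None))  # data.pt
    print(processed_file_name(512))   # data_maxlen512.pt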
