diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py
index 817bc1d1..39e5fbec 100644
--- a/chebai/preprocessing/datasets/base.py
+++ b/chebai/preprocessing/datasets/base.py
@@ -79,6 +79,7 @@ def __init__(
         inner_k_folds: int = -1,  # use inner cross-validation if > 1
         fold_index: Optional[int] = None,
         base_dir: Optional[str] = None,
+        n_token_limit: Optional[int] = None,
         **kwargs,
     ):
         super().__init__()
@@ -110,6 +111,7 @@ def __init__(
         ), "fold_index can't be larger than the total number of folds"
         self.fold_index = fold_index
         self._base_dir = base_dir
+        self.n_token_limit = n_token_limit
         os.makedirs(self.raw_dir, exist_ok=True)
         os.makedirs(self.processed_dir, exist_ok=True)
         if self.use_inner_cross_validation:
@@ -311,8 +313,16 @@ def _load_data_from_file(self, path: str) -> List[Dict[str, Any]]:
             for d in tqdm.tqdm(self._load_dict(path), total=lines)
             if d["features"] is not None
         ]
-        # filter for missing features in resulting data
-        data = [val for val in data if val["features"] is not None]
+        # drop entries with missing features; enforce the token limit if set
+        data = [
+            val
+            for val in data
+            if val["features"] is not None
+            and (
+                self.n_token_limit is None
+                or len(val["features"]) <= self.n_token_limit
+            )
+        ]
 
         return data
 
@@ -1181,4 +1191,6 @@ def processed_file_names_dict(self) -> dict:
             dict: A dictionary mapping dataset keys to their respective file names.
                 For example, {"data": "data.pt"}.
         """
+        if self.n_token_limit is not None:
+            return {"data": f"data_maxlen{self.n_token_limit}.pt"}
        return {"data": "data.pt"}
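
For reference, a minimal standalone sketch of the filtering rule this patch introduces: entries without features are always dropped, and when a token limit is set, entries whose feature sequence exceeds it are dropped as well. The `filter_by_token_limit` helper and the sample records are illustrative only, not part of the patch:

```python
from typing import Any, Dict, List, Optional


def filter_by_token_limit(
    data: List[Dict[str, Any]], n_token_limit: Optional[int]
) -> List[Dict[str, Any]]:
    """Keep entries that have features and respect the optional token limit."""
    return [
        val
        for val in data
        if val["features"] is not None
        and (n_token_limit is None or len(val["features"]) <= n_token_limit)
    ]


if __name__ == "__main__":
    records = [
        {"features": [1, 2, 3]},  # kept when the limit is >= 3
        {"features": [1, 2, 3, 4, 5]},  # dropped when the limit is < 5
        {"features": None},  # always dropped: no features
    ]
    assert len(filter_by_token_limit(records, n_token_limit=4)) == 1
    assert len(filter_by_token_limit(records, n_token_limit=None)) == 2
```

The parentheses around the limit check matter: `A and B or C` parses as `(A and B) or C`, so without them `len(val["features"])` would be evaluated even when `features` is `None`, raising a `TypeError`.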