22
33from random import Random
44
5- from sklearn .model_selection import train_test_split
6-
75from autointent import Dataset
86from autointent .custom_types import Split
97
@@ -32,19 +30,17 @@ def _sample_intent_regexp(
3230def sample_from_regex (
3331 in_dataset : Dataset ,
3432 n_shots : int ,
33+ split_name : str = Split .TRAIN ,
3534 n_rep_limit : int = 20 ,
36- val_size : float = 0.2 ,
37- test_size : float = 0.2 ,
3835 random_seed : int | None = None ,
3936) -> Dataset :
4037 """
4138 Generate utterances from dataset with regular expressions.
4239
4340 :param in_dataset: The dataset containing intents with regular exressions.
4441 :param n_shots: The maximum number of samples to produce for every intent.
42+ :param split_name: Where to put the data.
4543 :param n_rep_limit: To limit the number of possible repetitions in a regular expression.
46- :param val_size: The proportion to be allocated for the validation part.
47- :param test_size: The proportion to be allocated for the test part.
4844 :param random_seed: To make your sampling deterministic.
4945
5046 :returns: The dataset with sampled utterances.
@@ -53,23 +49,12 @@ def sample_from_regex(
5349 intents = in_dataset .intents
5450
5551 splits : dict [str , list ] = { # type: ignore[type-arg]
56- Split .TRAIN : [],
57- Split .VALIDATION : [],
58- Split .TEST : [],
52+ split_name : []
5953 }
6054
6155 for intent in intents :
6256 utterances = _sample_intent_regexp (intent .regexp_full_match , n_shots , n_rep_limit , intent .id , rng )
63-
64- x_train , x_remaining = train_test_split (utterances , test_size = val_size + test_size , random_state = random_seed )
65- splits [Split .TRAIN ].extend (x_train )
66-
67- x_val , x_test = train_test_split (
68- x_remaining , test_size = test_size / (test_size + val_size ), random_state = random_seed
69- )
70-
71- splits [Split .VALIDATION ].extend (x_val )
72- splits [Split .TEST ].extend (x_test )
57+ splits [split_name ].extend (utterances )
7358
7459 splits ["intents" ] = intents
7560
0 commit comments