Skip to content

Commit a08be2b

Browse files
committed
Changes according to the review request
1 parent 84dfd05 commit a08be2b

File tree

2 files changed

+6
-37
lines changed

2 files changed

+6
-37
lines changed

autointent/generation/regex_generation.py

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22

33
from random import Random
44

5-
from sklearn.model_selection import train_test_split
6-
75
from autointent import Dataset
86
from autointent.custom_types import Split
97

@@ -32,19 +30,17 @@ def _sample_intent_regexp(
3230
def sample_from_regex(
3331
in_dataset: Dataset,
3432
n_shots: int,
33+
split_name: str = Split.TRAIN,
3534
n_rep_limit: int = 20,
36-
val_size: float = 0.2,
37-
test_size: float = 0.2,
3835
random_seed: int | None = None,
3936
) -> Dataset:
4037
"""
4138
Generate utterances from dataset with regular expressions.
4239
4340
:param in_dataset: The dataset containing intents with regular exressions.
4441
:param n_shots: The maximum number of samples to produce for every intent.
42+
:param split_name: Where to put the data.
4543
:param n_rep_limit: To limit the number of possible repetitions in a regular expression.
46-
:param val_size: The proportion to be allocated for the validation part.
47-
:param test_size: The proportion to be allocated for the test part.
4844
:param random_seed: To make your sampling deterministic.
4945
5046
:returns: The dataset with sampled utterances.
@@ -53,23 +49,12 @@ def sample_from_regex(
5349
intents = in_dataset.intents
5450

5551
splits: dict[str, list] = { # type: ignore[type-arg]
56-
Split.TRAIN: [],
57-
Split.VALIDATION: [],
58-
Split.TEST: [],
52+
split_name: []
5953
}
6054

6155
for intent in intents:
6256
utterances = _sample_intent_regexp(intent.regexp_full_match, n_shots, n_rep_limit, intent.id, rng)
63-
64-
x_train, x_remaining = train_test_split(utterances, test_size=val_size + test_size, random_state=random_seed)
65-
splits[Split.TRAIN].extend(x_train)
66-
67-
x_val, x_test = train_test_split(
68-
x_remaining, test_size=test_size / (test_size + val_size), random_state=random_seed
69-
)
70-
71-
splits[Split.VALIDATION].extend(x_val)
72-
splits[Split.TEST].extend(x_test)
57+
splits[split_name].extend(utterances)
7358

7459
splits["intents"] = intents
7560

tests/generation/test_regex_generation.py

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,13 @@ def in_dataset():
1616
def test_generation_basic(in_dataset):
1717
result = sample_from_regex(in_dataset, n_shots=3)
1818

19-
assert len(result[Split.TRAIN]) == 3
20-
assert len(result[Split.VALIDATION]) == 3
21-
assert len(result[Split.TEST]) == 3
19+
assert len(result[Split.TRAIN]) == 9
2220

2321

2422
def test_generation_all_samples(in_dataset):
2523
result = sample_from_regex(in_dataset, n_shots=1000)
2624

27-
assert len(result[Split.TRAIN]) == 1273
28-
assert len(result[Split.VALIDATION]) == 424
29-
assert len(result[Split.TEST]) == 425
25+
assert len(result[Split.TRAIN]) == 2122
3026

3127

3228
def test_generation_deterministic(in_dataset):
@@ -36,22 +32,10 @@ def test_generation_deterministic(in_dataset):
3632
assert len(result1[Split.TRAIN]) != 0
3733
assert result1[Split.TRAIN][Dataset.utterance_feature] == result2[Split.TRAIN][Dataset.utterance_feature]
3834

39-
assert len(result1[Split.VALIDATION]) != 0
40-
assert result1[Split.VALIDATION][Dataset.utterance_feature] == result2[Split.VALIDATION][Dataset.utterance_feature]
41-
42-
assert len(result1[Split.TEST]) != 0
43-
assert result1[Split.TEST][Dataset.utterance_feature] == result2[Split.TEST][Dataset.utterance_feature]
44-
4535

4636
def test_generation_deterministic_different_seed(in_dataset):
4737
result1 = sample_from_regex(in_dataset, n_shots=3, random_seed=42)
4838
result2 = sample_from_regex(in_dataset, n_shots=3, random_seed=40)
4939

5040
assert len(result1[Split.TRAIN]) != 0
5141
assert result1[Split.TRAIN][Dataset.utterance_feature] != result2[Split.TRAIN][Dataset.utterance_feature]
52-
53-
assert len(result1[Split.VALIDATION]) != 0
54-
assert result1[Split.VALIDATION][Dataset.utterance_feature] != result2[Split.VALIDATION][Dataset.utterance_feature]
55-
56-
assert len(result1[Split.TEST]) != 0
57-
assert result1[Split.TEST][Dataset.utterance_feature] != result2[Split.TEST][Dataset.utterance_feature]

0 commit comments

Comments
 (0)