
Commit c89139a

Author: FelixAbrahamsson
Commit message: Add Dataset.group_split
1 parent 959258e

File tree: 4 files changed (+147, −10 lines)

datastream/dataset.py

Lines changed: 102 additions & 4 deletions
@@ -256,6 +256,59 @@ def split(
             ).items()
         }

+    def group_split(
+        self,
+        split_column: str,
+        proportions: Dict[str, float],
+        filepath: Optional[Union[str, Path]] = None,
+        frozen: Optional[bool] = False,
+        seed: Optional[int] = None,
+    ) -> Dict[str, Dataset[T]]:
+        '''
+        Similar to :func:`Dataset.split`, but uses a non-unique split column
+        instead of a unique key column. This is useful for example when you
+        have a dataset with examples that come from separate sources and you
+        don't want to have examples from the same source in different splits.
+        Does not support stratification.
+
+        >>> split_file = Path('doctest_split_dataset.json')
+        >>> split_datasets = (
+        ...     Dataset.from_dataframe(pd.DataFrame(dict(
+        ...         source=np.arange(100) // 4,
+        ...         number=np.random.randn(100),
+        ...     )))
+        ...     .group_split(
+        ...         split_column='source',
+        ...         proportions=dict(train=0.8, test=0.2),
+        ...         filepath=split_file,
+        ...     )
+        ... )
+        >>> len(split_datasets['train'])
+        80
+        >>> split_file.unlink() # clean up after doctest
+        '''
+        if filepath is not None:
+            filepath = Path(filepath)
+
+        split_dataframes = tools.group_split_dataframes
+        if seed is not None:
+            split_dataframes = tools.numpy_seed(seed)(split_dataframes)
+
+        return {
+            split_name: Dataset(
+                dataframe=dataframe,
+                length=len(dataframe),
+                functions=self.functions,
+            )
+            for split_name, dataframe in split_dataframes(
+                self.dataframe,
+                split_column,
+                proportions,
+                filepath,
+                frozen,
+            ).items()
+        }
+
     def zip_index(self: Dataset[T]) -> Dataset[Tuple[T, int]]:
         '''
         Zip the output with its index. The output of the pipeline will be
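
The guarantee group_split adds over split is that all rows sharing a value in split_column land in the same part. A minimal sketch checking that guarantee (not part of the commit; it assumes the datastream package with this change is importable and that Dataset exposes its dataframe attribute, as the constructor calls above suggest):

import numpy as np
import pandas as pd
from datastream import Dataset

# 25 groups of 4 rows each; rows of a group must stay together.
dataset = Dataset.from_dataframe(pd.DataFrame(dict(
    source=np.arange(100) // 4,
    number=np.random.randn(100),
)))
splits = dataset.group_split(
    split_column='source',
    proportions=dict(train=0.8, test=0.2),
)
# No group id may appear in both parts.
train_groups = set(splits['train'].dataframe['source'])
test_groups = set(splits['test'].dataframe['source'])
assert train_groups.isdisjoint(test_groups)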
@@ -343,10 +396,10 @@ def from_combine(index):
     def create_to_combine_mapping(datasets):
         cumprod_lengths = np.cumprod(list(map(len, datasets)))
         def to_concat(inner_indices):
-            return inner_indices[0] + sum(
-                [inner_index * cumprod_lengths[i]
-                for i, inner_index in enumerate(inner_indices[1:])]
-            )
+            return inner_indices[0] + sum([
+                inner_index * cumprod_lengths[i]
+                for i, inner_index in enumerate(inner_indices[1:])
+            ])
         return to_concat

     @staticmethod
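
The to_concat change above is a cosmetic reflow, but the arithmetic is worth spelling out: cumprod_lengths makes the per-dataset indices behave like mixed-radix digits of a single flat index over all combinations. A self-contained illustration with made-up lengths:

import numpy as np

lengths = [2, 3, 4]
cumprod_lengths = np.cumprod(lengths)  # [2, 6, 24]

def to_concat(inner_indices):
    # Same expression as in the diff: the first index plus each later
    # index scaled by the product of the preceding dataset lengths.
    return inner_indices[0] + sum([
        inner_index * cumprod_lengths[i]
        for i, inner_index in enumerate(inner_indices[1:])
    ])

# (1, 2, 3) -> 1 + 2 * 2 + 3 * 6 = 23, the last of 2 * 3 * 4 = 24 combinations.
assert to_concat((1, 2, 3)) == 23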
@@ -548,3 +601,48 @@ def test_split_dataset():
     assert split_datasets1 != split_datasets3
     assert split_datasets3 == split_datasets4
     assert split_datasets3 != split_datasets5
+
+
+def test_group_split_dataset():
+    dataset = Dataset.from_dataframe(pd.DataFrame(dict(
+        group=np.arange(100) // 4,
+        number=np.random.randn(100),
+    ))).map(tuple)
+
+    split_file = Path('test_split_dataset.json')
+    proportions = dict(
+        gradient=0.7,
+        early_stopping=0.15,
+        compare=0.15,
+    )
+
+    kwargs = dict(
+        split_column='group',
+        proportions=proportions,
+        filepath=split_file,
+    )
+
+    split_datasets1 = dataset.group_split(**kwargs)
+    split_datasets2 = dataset.group_split(**kwargs)
+    split_datasets3 = dataset.group_split(
+        split_column='group',
+        proportions=proportions,
+        seed=100,
+    )
+    split_datasets4 = dataset.group_split(
+        split_column='group',
+        proportions=proportions,
+        seed=100,
+    )
+    split_datasets5 = dataset.group_split(
+        split_column='group',
+        proportions=proportions,
+        seed=800,
+    )
+
+    split_file.unlink()
+
+    assert split_datasets1 == split_datasets2
+    assert split_datasets1 != split_datasets3
+    assert split_datasets3 == split_datasets4
+    assert split_datasets3 != split_datasets5

datastream/datastream.py

Lines changed: 1 addition & 1 deletion (old and new lines render identically, so this appears to be a whitespace-only fix to the docstring delimiter)

@@ -118,7 +118,7 @@ def zip(datastreams: List[Datastream]) -> Datastream[Tuple]:
     def map(
         self: Datastream[T], function: Callable[[T], R]
     ) -> Datastream[R]:
-        '''
+        '''
         Creates a new Datastream with a new mapped dataset. See
         :func:`Dataset.map` for details.
         '''

datastream/tools/__init__.py

Lines changed: 3 additions & 1 deletion

@@ -2,4 +2,6 @@
 from datastream.tools.starcompose import starcompose
 from datastream.tools.repeat_map_chain import repeat_map_chain
 from datastream.tools.numpy_seed import numpy_seed
-from datastream.tools.split_dataframes import split_dataframes
+from datastream.tools.split_dataframes import (
+    split_dataframes, group_split_dataframes
+)
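
numpy_seed, re-exported here next to the new group_split_dataframes, is what makes the seed argument of Dataset.group_split reproducible: tools.numpy_seed(seed)(split_dataframes) wraps the split function so numpy's RNG is seeded for that call. A hedged sketch of such a decorator (the actual datastream.tools.numpy_seed may differ, for example in whether it restores the previous RNG state):

import functools
import numpy as np

def numpy_seed(seed):
    def decorator(function):
        @functools.wraps(function)
        def wrapper(*args, **kwargs):
            state = np.random.get_state()  # save global RNG state
            np.random.seed(seed)
            try:
                return function(*args, **kwargs)
            finally:
                np.random.set_state(state)  # restore afterwards
        return wrapper
    return decorator

@numpy_seed(100)
def draw():
    return np.random.randn(3)

assert (draw() == draw()).all()  # same seed on every call, same numbers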

datastream/tools/split_dataframes.py

Lines changed: 41 additions & 4 deletions (two of the hunks below are whitespace-only cleanups where old and new lines render identically)
@@ -1,5 +1,5 @@
 from __future__ import annotations
-from typing import Tuple, Union, Dict, Optional
+from typing import Tuple, Dict, Optional
 from pathlib import Path
 import json
 import numpy as np
@@ -27,7 +27,7 @@ def split_dataframes(
             'Expected sum of proportions to be 1.',
             f'Proportions were {tuple(proportions.values())}',
         ]))
-
+
     if filepath is not None and filepath.exists():
         split = json.loads(filepath.read_text())

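The filepath parameter persists the split so later runs reuse it. Judging from the read path above and the split[split_name] membership lookup later in this file, the format is a JSON object mapping each split name to its list of assigned keys; a sketch under that (inferred) assumption:

import json
from pathlib import Path

# Hypothetical file as an earlier run might have written it
# (schema inferred from this diff: split name -> list of keys).
filepath = Path('split_example.json')
filepath.write_text(json.dumps(dict(train=[0, 1, 2], test=[3])))

split = json.loads(filepath.read_text())
assert set(split) == {'train', 'test'}
assert 3 in split['test']
filepath.unlink()  # clean up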
@@ -77,12 +77,31 @@ def split_dataframes(

     return {
         split_name: (
-            dataframe[lambda df: df[key_column].isin(split[split_name])]
+            dataframe[dataframe[key_column].isin(split[split_name])]
         )
         for split_name in proportions.keys()
     }


+def group_split_dataframes(
+    dataframe: pd.DataFrame,
+    split_column: str,
+    proportions: Dict[str, float],
+    filepath: Optional[Path] = None,
+    frozen: Optional[bool] = False,
+):
+    key_dataframe = pd.DataFrame(dict(key=dataframe[split_column].unique()))
+    splits = split_dataframes(
+        key_dataframe, 'key', proportions, filepath=filepath, frozen=frozen
+    )
+    return {
+        split_name: (
+            dataframe[dataframe[split_column].isin(split['key'])]
+        )
+        for split_name, split in splits.items()
+    }
+
+
 def stratas(dataframe, stratify_column):
     return [
         dataframe[lambda df: df[stratify_column] == strata_value]
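
group_split_dataframes reduces the group-wise problem to the already-tested key-wise one: collect the unique values of split_column into a one-column key frame, split that with split_dataframes, then select each part's rows by membership. The same two-step idea in plain pandas, stripped of the filepath/frozen bookkeeping (column name and proportions are illustrative):

import numpy as np
import pandas as pd

dataframe = pd.DataFrame(dict(
    group=np.arange(20) // 4,  # 5 groups of 4 rows
    value=np.random.randn(20),
))

# Step 1: split the unique group keys.
keys = dataframe['group'].unique()
np.random.shuffle(keys)
n_train = int(round(0.8 * len(keys)))
split = dict(train=keys[:n_train], test=keys[n_train:])

# Step 2: select each part's rows by key membership.
parts = {
    name: dataframe[dataframe['group'].isin(members)]
    for name, members in split.items()
}
assert set(parts['train']['group']).isdisjoint(parts['test']['group'])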
@@ -136,7 +155,7 @@ def n_target_split(keys, proportion):
 def selected(k, unassigned):
     return np.random.choice(
         unassigned, size=k, replace=False
-    ).tolist()
+    ).tolist()


 def mock_dataframe():
@@ -166,6 +185,24 @@ def test_standard():
     assert tuple(map(len, split_dataframes_.values())) == (80, 10, 10)


+def test_group_split_dataframe():
+    dataframe = mock_dataframe().assign(group=lambda df: df['index'] // 4)
+    split_dataframes_ = group_split_dataframes(
+        dataframe,
+        split_column='group',
+        proportions=dict(
+            train=0.8,
+            compare=0.2,
+        ),
+    )
+    group_overlap = (
+        set(split_dataframes_['train'].group)
+        .intersection(split_dataframes_['compare'].group)
+    )
+    assert len(group_overlap) == 0
+    assert tuple(map(len, split_dataframes_.values())) == (80, 20)
+
+
 def test_validate_proportions():
     from pytest import raises

171208

0 commit comments