Skip to content

Commit d7a5fa9

Browse files
committed
added logic to create two way split (e.g. train & other)
1 parent 8cd99d2 commit d7a5fa9

File tree

1 file changed

+335
-1
lines changed

1 file changed

+335
-1
lines changed

coderdata/dataset/dataset.py

Lines changed: 335 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,11 @@
2525
from sklearn.model_selection import StratifiedShuffleSplit
2626

2727

28+
@dataclass
29+
class TwoWaySplit:
30+
train: Dataset
31+
other: Dataset
32+
2833
@dataclass
2934
class Split:
3035
train: Dataset
@@ -330,7 +335,28 @@ def format(
330335
):
331336
return format(self, data_type=data_type, use_polars=use_polars, **kwargs)
332337

333-
338+
339+
def split_train_other(
340+
self,
341+
split_type: Literal[
342+
'mixed-set', 'drug-blind', 'cancer-blind'
343+
]='mixed-set',
344+
ratio: tuple[int, int, int]=(8,2),
345+
stratify_by: Optional[str]=None,
346+
random_state: Optional[Union[int,RandomState]]=None,
347+
**kwargs: dict,
348+
) -> TwoWaySplit:
349+
350+
split = split_train_other(
351+
data=self,
352+
split_type=split_type,
353+
ration=ratio,
354+
stratify_by=stratify_by,
355+
random_state=random_state,
356+
**kwargs
357+
)
358+
359+
return split
334360
def train_test_validate(
335361
self,
336362
split_type: Literal[
@@ -665,6 +691,30 @@ def format(
665691

666692
return ret
667693

694+
695+
696+
def split_train_other(
697+
data: Dataset,
698+
split_type: Literal[
699+
'mixed-set', 'drug-blind', 'cancer-blind'
700+
]='mixed-set',
701+
ratio: tuple[int, int, int]=(8,2),
702+
stratify_by: Optional[str]=None,
703+
random_state: Optional[Union[int,RandomState]]=None,
704+
**kwargs: dict,
705+
):
706+
train, other = _split_two_way(
707+
data,
708+
split_type,
709+
ratio,
710+
stratify_by,
711+
random_state,
712+
kwargs=kwargs
713+
)
714+
if stratify_by is not None:
715+
train.experiments = train.experiments[train.experiments['dose_response_metric'] != 'split_class']
716+
other.experiments = other.experiments[other.experiments['dose_response_metric'] != 'split_class']
717+
return TwoWaySplit(train=train, other=other)
668718
def train_test_validate(
669719
data: Dataset,
670720
split_type: Literal[
@@ -1194,3 +1244,287 @@ def _create_classes(
11941244
)
11951245

11961246
return data
1247+
1248+
1249+
def _split_two_way(
1250+
data: Dataset,
1251+
split_type: Literal[
1252+
'mixed-set', 'drug-blind', 'cancer-blind'
1253+
]='mixed-set',
1254+
ratio: tuple[int, int, int]=(8,2),
1255+
stratify_by: Optional[str]=None,
1256+
random_state: Optional[Union[int,RandomState]]=None,
1257+
**kwargs: dict,
1258+
) -> tuple[Dataset, Dataset]:
1259+
"""
1260+
Splits a `CoderData` object (see also
1261+
`coderdata.load.loader.DatasetLoader`) into three subsets for
1262+
training, testing and validating machine learning algorithms.
1263+
1264+
The size of the splits can be adjusted to be different from 80:10:10
1265+
(the default)for train:test:validate. The function also allows for
1266+
additional optional arguments, that define the type of split that is
1267+
performed ('mixed-set', 'drug-blind', 'cancer-blind'), if the splits
1268+
should be stratified (and which drug response metric to use), as
1269+
well as a random seed to enable the creation of reproducable splits.
1270+
Furhermore, a list of keyword arguments can be defined that will be
1271+
passed to the stratification function if so desired.
1272+
1273+
Parameters
1274+
----------
1275+
data : DatasetLoader
1276+
CoderData object containing a full dataset either downloaded
1277+
from the CoderData repository (see also
1278+
`coderdata.download.downloader.download_data_by_prefix`) or
1279+
built locally via the `build_all` process. The object must first
1280+
be loaded via `coderdata.load.loader.DatasetLoader`.
1281+
split_type : {'mixed-set', 'drug-blind', 'cancer-blind'}, \
1282+
default='mixed-set'
1283+
1284+
Defines the type of split that should be generated:
1285+
1286+
- *mixed-set*: Splits randomly independent of drug / cancer
1287+
association of the samples. Individual drugs or cancer types
1288+
can appear in all three splits
1289+
- *drug-blind*: Splits according to drug association. Any sample
1290+
associated with a drug will be unique to one of the splits.
1291+
For example samples with association to drug A will only be
1292+
present in the train split, but never in test or validate.
1293+
- *cancer-blind*: Splits according to cancer association.
1294+
Equivalent to drug-blind, except cancer types will be unique
1295+
to splits.
1296+
ratio : tuple[int, int, int], default=(8,1,1)
1297+
Defines the size ratio of the resulting test, train and
1298+
validation sets.
1299+
stratify_by : str | None, default=None
1300+
Defines if the training, testing and validation sets should be
1301+
stratified. Any value other than None indicates stratification
1302+
and defines which drug response value should be used as basis
1303+
for the stratification. _None_ indicates that no stratfication
1304+
should be performed.
1305+
random_state : int | RandomState | None, defaul=None
1306+
Defines a seed value for the randomization of the splits. Will
1307+
get passed to internal functions. Providing the seed will enable
1308+
reproducability of the generated splits.
1309+
**kwargs
1310+
Additional keyword arguments that will be passed to the function
1311+
that generates classes for the stratification
1312+
(see also ``_create_classes``).
1313+
1314+
Returns
1315+
-------
1316+
Splits : Split
1317+
A ``Split`` object that contains three Dataset objects as
1318+
attributes (``Split.train``, ``Split.test``,
1319+
``Split.validate``)
1320+
1321+
Raises
1322+
-------
1323+
ValueError :
1324+
If supplied `split_type` is not in the list of accepted values.
1325+
1326+
"""
1327+
1328+
# reading in the potential keyword arguments that will be passed to
1329+
# _create_classes().
1330+
thresh = kwargs.get('thresh', None)
1331+
num_classes = kwargs.get('num_classes', 2)
1332+
quantiles = kwargs.get('quantiles', True)
1333+
1334+
# Type checking split_type
1335+
if split_type not in [
1336+
'mixed-set', 'drug-blind', 'cancer-blind'
1337+
]:
1338+
raise ValueError(
1339+
f"{split_type} not an excepted input for 'split_type'"
1340+
)
1341+
1342+
# A wide (pivoted) table is more easy to work with in this instance.
1343+
# The pivot is done using all columns but the 'dose_respones_value'
1344+
# and 'dose_respones_metric' as index. df.pivot will generate a
1345+
# MultiIndex which complicates things further down the line. To that
1346+
# end 'reset_index()' is used to remove the MultiIndex
1347+
df_full = data.experiments.copy()
1348+
df_full = df_full.pivot(
1349+
index = [
1350+
'source',
1351+
'improve_sample_id',
1352+
'improve_drug_id',
1353+
'study',
1354+
'time',
1355+
'time_unit'
1356+
],
1357+
columns = 'dose_response_metric',
1358+
values = 'dose_response_value'
1359+
).reset_index()
1360+
1361+
# Defining the split sizes.
1362+
train_size = float(ratio[0]) / sum(ratio)
1363+
test_val_size = float(ratio[1]) / sum(ratio)
1364+
1365+
# ShuffleSplit is a method/class implemented by scikit-learn that
1366+
# enables creating splits where the data is shuffled and then
1367+
# randomly distributed into train and test sets according to the
1368+
# defined ratio.
1369+
#
1370+
# n_splits defines how often a train/test split is generated.
1371+
# Individual splits (if more than 1 is generated) are not guaranteed
1372+
# to be disjoint i.e. test sets from individual splits can overlap.
1373+
#
1374+
# ShuffleSplit will be used for non stratified mixed-set splitting
1375+
# since there is no requirement for disjoint groups (i.e. drug /
1376+
# sample ids).
1377+
shs = ShuffleSplit(
1378+
n_splits=1,
1379+
train_size=train_size,
1380+
test_size=test_val_size,
1381+
random_state=random_state
1382+
)
1383+
1384+
# GroupShuffleSplit is an extension to ShuffleSplit that also
1385+
# factors in a group that is used to generate disjoint train and
1386+
# test sets, e.g. in this particular case the drug or sample id to
1387+
# generate drug-blind or sample-blind train and test sets.
1388+
#
1389+
# GroupShuffleSplit will be used for non stratified drug-/sample-
1390+
# blind splitting, i.e. there is a requirement that instances from
1391+
# one group (e.g. a specific drug) are only present in the training
1392+
# set but not in the test set.
1393+
gss = GroupShuffleSplit(
1394+
n_splits=1,
1395+
train_size=train_size,
1396+
test_size=test_val_size,
1397+
random_state=random_state
1398+
)
1399+
1400+
# StratifiedShuffleSplit is similar to ShuffleSplit with the added
1401+
# functionality to also stratify the splits according to defined
1402+
# class labels.
1403+
#
1404+
# StratifiedShuffleSplit will be used for stratified mixed-set
1405+
# train/test/validate sets.
1406+
1407+
sss = StratifiedShuffleSplit(
1408+
n_splits=1,
1409+
train_size=train_size,
1410+
test_size=test_val_size,
1411+
random_state=random_state
1412+
)
1413+
1414+
# StratifiedGroupKFold generates K folds that take the group into
1415+
# account when generating folds, i.e. a group will only be present
1416+
# in one fold. It further tries to stratify the folds based on the
1417+
# defined classes.
1418+
#
1419+
# StratifiedGroupKFold will be used for stratified drug-/sample-
1420+
# blind splitting.
1421+
#
1422+
# The way the K folds are utilized is to combine i, j, & k folds
1423+
# (according to the defined ratio) into training, testing and
1424+
# validation sets.
1425+
sgk = StratifiedGroupKFold(
1426+
n_splits=sum(ratio),
1427+
shuffle=True,
1428+
random_state=random_state
1429+
)
1430+
1431+
# The "actual" splitting logic using the defined Splitters as above
1432+
# follows here starting with the non-stratified splitting:
1433+
if stratify_by is None:
1434+
if split_type == 'mixed-set':
1435+
# Using ShuffleSplit to generate randomized train and
1436+
# 'other' set, since there is no need for grouping.
1437+
idx1, idx2 = next(
1438+
shs.split(df_full)
1439+
)
1440+
elif split_type == 'drug-blind':
1441+
# Using GroupShuffleSplit to created disjoint train and
1442+
# 'other' sets by drug id
1443+
idx1, idx2 = next(
1444+
gss.split(df_full, groups=df_full.improve_drug_id)
1445+
)
1446+
elif split_type == 'cancer-blind':
1447+
# same as above we just group over the sample id
1448+
idx1, idx2 = next(
1449+
gss.split(df_full, groups=df_full.improve_sample_id)
1450+
)
1451+
else:
1452+
raise Exception(f"Should be unreachable")
1453+
1454+
# generate new DFs containing the subset of items extracted for
1455+
# train and other
1456+
df_train = df_full.iloc[idx1]
1457+
df_other = df_full.iloc[idx2]
1458+
1459+
1460+
# The following block contains the stratified splitting logic
1461+
else:
1462+
# First the classes that are needed for the stratification are
1463+
# generated. `num_classes`, `thresh` and `quantiles` were
1464+
# previously defined as possible keyword arguments.
1465+
if 'split_class' not in df_full.columns.to_list():
1466+
df_full = _create_classes(
1467+
data=df_full,
1468+
metric=stratify_by,
1469+
num_classes=num_classes,
1470+
thresh=thresh,
1471+
quantiles=quantiles,
1472+
)
1473+
if split_type == 'mixed-set':
1474+
# Using StratifiedShuffleSplit to generate randomized train
1475+
# and 'other' set, since there is no need for grouping.
1476+
idx_train, idx_other = next(
1477+
sss.split(X=df_full, y=df_full['split_class'])
1478+
)
1479+
df_train = df_full.iloc[idx_train]
1480+
# df_train = df_train.drop(labels=['split_class'], axis=1)
1481+
df_other = df_full.iloc[idx_other]
1482+
1483+
# using StratifiedGroupKSplit for the stratified drug-/sample-
1484+
# blind splits.
1485+
elif split_type == 'drug-blind' or split_type == 'cancer-blind':
1486+
if split_type == 'drug-blind':
1487+
splitter = enumerate(
1488+
sgk.split(
1489+
X=df_full,
1490+
y=df_full['split_class'],
1491+
groups=df_full.improve_drug_id
1492+
)
1493+
)
1494+
elif split_type == 'cancer-blind':
1495+
splitter = enumerate(
1496+
sgk.split(
1497+
X=df_full,
1498+
y=df_full['split_class'],
1499+
groups=df_full.improve_sample_id
1500+
)
1501+
)
1502+
1503+
# StratifiedGroupKSplit is setup to generate K splits where
1504+
# K=sum(ratios) (e.g. 10 if ratio=8:1:1). To obtain three
1505+
# sets (train/test/validate) the individual splits need to
1506+
# be combined (e.g. k=[1:8] -> train, k=9 -> test, k=10 ->
1507+
# validate). The code block below does that by combining
1508+
# all indices (row numbers) that go into individual sets and
1509+
# then extracting and adding those rows into the individual
1510+
# sets.
1511+
idx_train = []
1512+
idx_other = []
1513+
for i, (idx1, idx2) in splitter:
1514+
if i < ratio[0]:
1515+
idx_train.extend(idx2)
1516+
elif i >= ratio[0]:
1517+
idx_other.extend(idx2)
1518+
# df_full.drop(labels=['split_class'], axis=1, inplace=True)
1519+
df_train = df_full.iloc[idx_train]
1520+
df_other = df_full.iloc[idx_other]
1521+
else:
1522+
raise Exception(f"Should be unreachable")
1523+
1524+
1525+
# generating filtered CoderData objects that contain only the
1526+
# respective data for each split
1527+
data_train = _filter(data, df_train)
1528+
data_other = _filter(data, df_other)
1529+
1530+
return (data_train, data_other)

0 commit comments

Comments
 (0)