|
25 | 25 | from sklearn.model_selection import StratifiedShuffleSplit |
26 | 26 |
|
27 | 27 |
|
@dataclass
class TwoWaySplit:
    """
    Container for the result of a two-way dataset split.

    Instances are created by ``split_train_other`` and bundle the two
    resulting ``Dataset`` objects so that they can be accessed by name
    (``TwoWaySplit.train`` and ``TwoWaySplit.other``).
    """

    # Subset of the data selected for training.
    train: Dataset
    # Subset holding everything that was not assigned to ``train``.
    other: Dataset
| 32 | + |
28 | 33 | @dataclass |
29 | 34 | class Split: |
30 | 35 | train: Dataset |
@@ -330,7 +335,28 @@ def format( |
330 | 335 | ): |
331 | 336 | return format(self, data_type=data_type, use_polars=use_polars, **kwargs) |
332 | 337 |
|
333 | | - |
| 338 | + |
def split_train_other(
        self,
        split_type: Literal[
            'mixed-set', 'drug-blind', 'cancer-blind'
            ]='mixed-set',
        ratio: tuple[int, int]=(8,2),
        stratify_by: Optional[str]=None,
        random_state: Optional[Union[int,RandomState]]=None,
        **kwargs: dict,
        ) -> TwoWaySplit:
    """
    Splits this dataset into a training set and one 'other' set.

    Thin convenience wrapper that forwards all arguments to the
    module level function ``split_train_other`` with ``data=self``.

    Parameters
    ----------
    split_type : {'mixed-set', 'drug-blind', 'cancer-blind'}, \
            default='mixed-set'
        Type of split that should be generated.
    ratio : tuple[int, int], default=(8,2)
        Size ratio of the resulting train and other sets.
    stratify_by : str | None, default=None
        Drug response metric used for stratification. ``None``
        disables stratification.
    random_state : int | RandomState | None, default=None
        Seed / random state used to make the split reproducible.
    **kwargs
        Additional keyword arguments forwarded unchanged to the module
        level ``split_train_other``.

    Returns
    -------
    TwoWaySplit
        Object holding the two resulting datasets as ``train`` and
        ``other``.
    """
    # NOTE: the keyword has to be spelled ``ratio``. A misspelled
    # keyword is swallowed by the callee's ``**kwargs`` and the
    # default ratio would be used silently instead.
    return split_train_other(
        data=self,
        split_type=split_type,
        ratio=ratio,
        stratify_by=stratify_by,
        random_state=random_state,
        **kwargs
        )
334 | 360 | def train_test_validate( |
335 | 361 | self, |
336 | 362 | split_type: Literal[ |
@@ -665,6 +691,30 @@ def format( |
665 | 691 |
|
666 | 692 | return ret |
667 | 693 |
|
| 694 | + |
| 695 | + |
def split_train_other(
        data: Dataset,
        split_type: Literal[
            'mixed-set', 'drug-blind', 'cancer-blind'
            ]='mixed-set',
        ratio: tuple[int, int]=(8,2),
        stratify_by: Optional[str]=None,
        random_state: Optional[Union[int,RandomState]]=None,
        **kwargs: dict,
        ) -> TwoWaySplit:
    """
    Splits a dataset into a training set and one 'other' set.

    The actual splitting is delegated to ``_split_two_way``; this
    function post-processes the two resulting datasets and wraps them
    in a ``TwoWaySplit``.

    Parameters
    ----------
    data : Dataset
        The dataset that should be split.
    split_type : {'mixed-set', 'drug-blind', 'cancer-blind'}, \
            default='mixed-set'
        Type of split that should be generated.
    ratio : tuple[int, int], default=(8,2)
        Size ratio of the resulting train and other sets.
    stratify_by : str | None, default=None
        Drug response metric used for stratification. ``None``
        disables stratification.
    random_state : int | RandomState | None, default=None
        Seed / random state used to make the split reproducible.
    **kwargs
        Additional keyword arguments forwarded to ``_split_two_way``,
        which reads ``thresh``, ``num_classes`` and ``quantiles``.

    Returns
    -------
    TwoWaySplit
        Object holding the two resulting datasets as ``train`` and
        ``other``.
    """
    # The keyword arguments must be unpacked here. Passing the dict
    # as a single ``kwargs=...`` keyword would nest it under the key
    # 'kwargs' and the stratification options would never be read.
    train, other = _split_two_way(
        data,
        split_type,
        ratio,
        stratify_by,
        random_state,
        **kwargs
        )

    if stratify_by is not None:
        # Rows labelled 'split_class' are presumably helper rows added
        # while stratifying (TODO confirm); they are removed so the
        # returned experiments tables only contain response metrics.
        train.experiments = train.experiments[
            train.experiments['dose_response_metric'] != 'split_class'
            ]
        other.experiments = other.experiments[
            other.experiments['dose_response_metric'] != 'split_class'
            ]

    return TwoWaySplit(train=train, other=other)
668 | 718 | def train_test_validate( |
669 | 719 | data: Dataset, |
670 | 720 | split_type: Literal[ |
@@ -1194,3 +1244,287 @@ def _create_classes( |
1194 | 1244 | ) |
1195 | 1245 |
|
1196 | 1246 | return data |
| 1247 | + |
| 1248 | + |
def _split_two_way(
        data: Dataset,
        split_type: Literal[
            'mixed-set', 'drug-blind', 'cancer-blind'
            ]='mixed-set',
        ratio: tuple[int, int]=(8,2),
        stratify_by: Optional[str]=None,
        random_state: Optional[Union[int,RandomState]]=None,
        **kwargs: dict,
        ) -> tuple[Dataset, Dataset]:
    """
    Splits a `Dataset` into two subsets (train and 'other').

    The size of the two subsets can be adjusted to be different from
    80:20 (the default) for train:other. The function also allows for
    additional optional arguments that define the type of split that
    is performed ('mixed-set', 'drug-blind', 'cancer-blind'), if the
    split should be stratified (and which drug response metric to
    use), as well as a random seed to enable the creation of
    reproducible splits. Furthermore, keyword arguments can be defined
    that will be passed to the function that generates the classes for
    the stratification.

    Parameters
    ----------
    data : Dataset
        The dataset to split. Its ``experiments`` table is pivoted on
        'dose_response_metric' / 'dose_response_value' to derive the
        split.
    split_type : {'mixed-set', 'drug-blind', 'cancer-blind'}, \
            default='mixed-set'

        Defines the type of split that should be generated:

        - *mixed-set*: Splits randomly, independent of the drug /
          sample ids of the rows.
        - *drug-blind*: Splits grouped by 'improve_drug_id', i.e. all
          rows of one drug end up in the same subset.
        - *cancer-blind*: Splits grouped by 'improve_sample_id', i.e.
          all rows of one sample end up in the same subset.
    ratio : tuple[int, int], default=(8,2)
        Defines the size ratio of the resulting train and other sets.
        Only ``ratio[0]`` and ``ratio[1]`` are read individually, but
        all entries contribute to ``sum(ratio)``.
    stratify_by : str | None, default=None
        Defines if the split should be stratified. Any value other
        than None indicates stratification and is passed as ``metric``
        to ``_create_classes``. _None_ indicates that no
        stratification should be performed.
    random_state : int | RandomState | None, default=None
        Defines a seed value for the randomization of the split. It is
        passed to the scikit-learn splitter that is used. Providing
        the seed enables reproducibility of the generated split.
    **kwargs
        Additional keyword arguments for the class generation (see
        also ``_create_classes``). The keys ``thresh`` (default None),
        ``num_classes`` (default 2) and ``quantiles`` (default True)
        are read; other keys are ignored.

    Returns
    -------
    tuple[Dataset, Dataset]
        The train dataset followed by the other dataset, each
        obtained by applying ``_filter`` to `data`.

    Raises
    ------
    ValueError :
        If supplied `split_type` is not in the list of accepted values.

    """

    # reading in the potential keyword arguments that will be passed to
    # _create_classes().
    thresh = kwargs.get('thresh', None)
    num_classes = kwargs.get('num_classes', 2)
    quantiles = kwargs.get('quantiles', True)

    # Maps each accepted split type to the column whose values must not
    # be shared between the two subsets (None -> no grouping needed).
    group_columns = {
        'mixed-set': None,
        'drug-blind': 'improve_drug_id',
        'cancer-blind': 'improve_sample_id',
        }

    # Type checking split_type
    if split_type not in group_columns:
        raise ValueError(
            f"{split_type} not an accepted input for 'split_type'"
            )
    group_column = group_columns[split_type]

    # A wide (pivoted) table is easier to work with in this instance.
    # The pivot is done using all columns but 'dose_response_value'
    # and 'dose_response_metric' as index. df.pivot will generate a
    # MultiIndex which complicates things further down the line. To that
    # end 'reset_index()' is used to remove the MultiIndex.
    df_full = data.experiments.copy()
    df_full = df_full.pivot(
        index = [
            'source',
            'improve_sample_id',
            'improve_drug_id',
            'study',
            'time',
            'time_unit'
            ],
        columns = 'dose_response_metric',
        values = 'dose_response_value'
        ).reset_index()

    # Defining the split sizes as fractions of the full data.
    ratio_total = sum(ratio)
    train_size = float(ratio[0]) / ratio_total
    other_size = float(ratio[1]) / ratio_total

    # The splitters are only instantiated in the branch that uses them.
    if stratify_by is None:
        if group_column is None:
            # ShuffleSplit shuffles the rows and distributes them
            # randomly into two sets according to the defined sizes.
            # It is sufficient for the non stratified mixed-set split
            # since there is no requirement for disjoint groups.
            shs = ShuffleSplit(
                n_splits=1,
                train_size=train_size,
                test_size=other_size,
                random_state=random_state
                )
            idx_train, idx_other = next(shs.split(df_full))
        else:
            # GroupShuffleSplit additionally keeps groups (drug or
            # sample id) disjoint between the two sets, which is what
            # the non stratified drug-/cancer-blind splits require.
            gss = GroupShuffleSplit(
                n_splits=1,
                train_size=train_size,
                test_size=other_size,
                random_state=random_state
                )
            idx_train, idx_other = next(
                gss.split(df_full, groups=df_full[group_column])
                )

    # The following block contains the stratified splitting logic
    else:
        # First the classes that are needed for the stratification are
        # generated (unless a 'split_class' column is already present).
        # `num_classes`, `thresh` and `quantiles` were previously
        # read from the keyword arguments.
        if 'split_class' not in df_full.columns:
            df_full = _create_classes(
                data=df_full,
                metric=stratify_by,
                num_classes=num_classes,
                thresh=thresh,
                quantiles=quantiles,
                )

        if group_column is None:
            # StratifiedShuffleSplit behaves like ShuffleSplit but also
            # stratifies the two sets according to the class labels.
            sss = StratifiedShuffleSplit(
                n_splits=1,
                train_size=train_size,
                test_size=other_size,
                random_state=random_state
                )
            idx_train, idx_other = next(
                sss.split(X=df_full, y=df_full['split_class'])
                )
        else:
            # StratifiedGroupKFold generates K = sum(ratio) folds in
            # which every group is held out in exactly one fold while
            # trying to preserve the class distribution. The held-out
            # parts of the first ratio[0] folds are combined into the
            # train set, those of the remaining folds into the other
            # set.
            sgk = StratifiedGroupKFold(
                n_splits=ratio_total,
                shuffle=True,
                random_state=random_state
                )
            folds = sgk.split(
                X=df_full,
                y=df_full['split_class'],
                groups=df_full[group_column]
                )
            idx_train = []
            idx_other = []
            for fold_number, (_, idx_held_out) in enumerate(folds):
                if fold_number < ratio[0]:
                    idx_train.extend(idx_held_out)
                else:
                    idx_other.extend(idx_held_out)

    # generate new DFs containing the subset of rows selected for
    # train and other
    df_train = df_full.iloc[idx_train]
    df_other = df_full.iloc[idx_other]

    # generating filtered Dataset objects that contain only the
    # respective data for each subset
    data_train = _filter(data, df_train)
    data_other = _filter(data, df_other)

    return (data_train, data_other)
0 commit comments