|
133 | 133 | "import warnings\n", |
134 | 134 | "from collections import OrderedDict\n", |
135 | 135 | "from pathlib import Path\n", |
136 | | - "from typing import TYPE_CHECKING, Callable\n", |
| 136 | + "from typing import TYPE_CHECKING\n", |
137 | 137 | "\n", |
138 | 138 | "# Third party imports\n", |
139 | 139 | "import joblib\n", |
|
192 | 192 | ")\n", |
193 | 193 | "\n", |
194 | 194 | "if TYPE_CHECKING: # pragma: no cover\n", |
195 | | - " from collections.abc import Iterator\n", |
| 195 | + " from collections.abc import Callable, Iterator\n", |
196 | 196 | "\n", |
197 | 197 | "warnings.filterwarnings(\"ignore\")\n", |
198 | 198 | "mpl.rcParams[\"figure.dpi\"] = 300 # for high resolution figure in notebook" |
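
Note on the import change above: `typing.Callable` is deprecated in favour of `collections.abc.Callable` (PEP 585, Python 3.9+), and since the name is only needed for annotations it can live under the `if TYPE_CHECKING:` guard, which is skipped at runtime (and excluded from coverage via `# pragma: no cover`). A minimal sketch of the same pattern, with made-up names, assuming `from __future__ import annotations` so the annotation is not evaluated at runtime:

    from __future__ import annotations

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:  # pragma: no cover
        # Imported only for static type checkers; never executed at runtime.
        from collections.abc import Callable

    def apply_once(fn: Callable[[int], int], x: int) -> int:
        # `Callable` appears only in the (string) annotation, so the guarded
        # import above is sufficient.
        return fn(x)

    print(apply_once(lambda v: v + 1, 41))  # 42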
|
394 | 394 | "patient_uids = patient_uids[sel]\n", |
395 | 395 | "patient_labels = patient_labels_[sel]\n", |
396 | 396 | "assert len(patient_uids) == len(patient_labels) # noqa: S101\n", |
397 | | - "clinical_info = OrderedDict(list(zip(patient_uids, patient_labels)))\n", |
| 397 | + "clinical_info = OrderedDict(list(zip(patient_uids, patient_labels, strict=False)))\n", |
398 | 398 | "\n", |
399 | 399 | "# Retrieve patient code of each WSI, this is based on TCGA barcodes:\n", |
400 | 400 | "# https://docs.gdc.cancer.gov/Encyclopedia/pages/TCGA_Barcode/\n", |
|
412 | 412 | "wsi_names = np.array(wsi_names)[sel]\n", |
413 | 413 | "wsi_labels = np.array(wsi_labels)[sel]\n", |
414 | 414 | "\n", |
415 | | - "label_df = list(zip(wsi_names, wsi_labels))\n", |
| 415 | + "label_df = list(zip(wsi_names, wsi_labels, strict=False))\n", |
416 | 416 | "label_df = pd.DataFrame(label_df, columns=[\"WSI-CODE\", \"LABEL\"])" |
417 | 417 | ] |
418 | 418 | }, |
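
The `strict=False` additions here and in the hunks below keep the behaviour identical: since Python 3.10, `zip()` accepts a `strict` keyword, and spelling it out explicitly is presumably meant to satisfy a lint rule such as Ruff's B905 (zip-without-explicit-strict). `strict=False` is the default and silently truncates to the shortest iterable, while `strict=True` raises `ValueError` on a length mismatch; given the `assert len(patient_uids) == len(patient_labels)` just above, `strict=True` would also be a valid, stricter choice. A small illustration with made-up values:

    # Hypothetical inputs of unequal length, for illustration only.
    wsi_names = ["wsi-1", "wsi-2", "wsi-3"]
    wsi_labels = [0, 1]

    # strict=False (the default): silently truncates to the shortest iterable.
    print(list(zip(wsi_names, wsi_labels, strict=False)))  # [('wsi-1', 0), ('wsi-2', 1)]

    # strict=True: fails loudly when the lengths differ.
    try:
        list(zip(wsi_names, wsi_labels, strict=True))
    except ValueError as err:
        print(err)  # zip() argument 2 is shorter than argument 1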
|
529 | 529 | "\n", |
530 | 530 | " splits.append(\n", |
531 | 531 | " {\n", |
532 | | - " \"train\": list(zip(train_x, train_y)),\n", |
533 | | - " \"valid\": list(zip(valid_x, valid_y)),\n", |
534 | | - " \"test\": list(zip(test_x, test_y)),\n", |
| 532 | + " \"train\": list(zip(train_x, train_y, strict=False)),\n", |
| 533 | + " \"valid\": list(zip(valid_x, valid_y, strict=False)),\n", |
| 534 | + " \"test\": list(zip(test_x, test_y, strict=False)),\n", |
535 | 535 | " },\n", |
536 | 536 | " )\n", |
537 | 537 | " return splits" |
|
2025 | 2025 | " output = [np.split(v, batch_size, axis=0) for v in output]\n", |
2026 | 2026 | " # pairing such that it will be\n", |
2027 | 2027 | " # N batch size x H head list\n", |
2028 | | - " output = list(zip(*output))\n", |
| 2028 | + " output = list(zip(*output, strict=False))\n", |
2029 | 2029 | " step_output.extend(output)\n", |
2030 | 2030 | " pbar.update()\n", |
2031 | 2031 | " pbar.close()\n", |
|
2042 | 2042 | " ):\n", |
2043 | 2043 | " # Expand the list of N dataset size x H heads\n", |
2044 | 2044 | " # back to a list of H Head each with N samples.\n", |
2045 | | - " output = list(zip(*step_output))\n", |
| 2045 | + " output = list(zip(*step_output, strict=False))\n", |
2046 | 2046 | " logit, true = output\n", |
2047 | 2047 | " logit = np.squeeze(np.array(logit))\n", |
2048 | 2048 | " true = np.squeeze(np.array(true))\n", |
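
`list(zip(*rows))` is the usual transpose idiom: it turns N rows of H items into H tuples of N items (and back again). In both hunks above every row has the same number of heads, so `strict=False` versus `strict=True` makes no practical difference; the keyword is again only spelled out to make the call explicit. A tiny sketch with made-up values:

    # Hypothetical output: H=2 heads, each with N=2 batch entries.
    per_head = [["b0_h0", "b1_h0"], ["b0_h1", "b1_h1"]]

    # Transpose to N batch entries, each holding H head outputs.
    per_batch = list(zip(*per_head, strict=False))
    print(per_batch)  # [('b0_h0', 'b0_h1'), ('b1_h0', 'b1_h1')]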
|