3
3
import signal
4
4
import traceback
5
5
import types
6
- from abc import ABC , abstractmethod
7
6
from enum import Enum
8
7
from multiprocessing import Process , Queue
9
8
from pathlib import Path
10
9
from queue import Empty
11
10
from shutil import copyfile
11
+ from textwrap import dedent
12
12
from threading import Thread
13
13
from time import sleep , time
14
- from typing import Any , Callable , Dict , List , Literal , Optional , Tuple
14
+ from typing import Any , Callable , Dict , List , Literal , Optional , Protocol , Tuple , TypeVar , runtime_checkable
15
15
from urllib import parse
16
16
17
17
from tqdm .auto import tqdm
@@ -167,7 +167,7 @@ def __init__(
167
167
start_index : int ,
168
168
dataset_name : str ,
169
169
node_rank : int ,
170
- dataset_optimizer : "DatasetOptimizer" ,
170
+ prepare_item : Callable ,
171
171
src_dir : str ,
172
172
remote_src_dir : str ,
173
173
remote_dst_dir : Optional [str ],
@@ -187,7 +187,7 @@ def __init__(
187
187
self .start_index = start_index
188
188
self .dataset_name = dataset_name
189
189
self .node_rank = node_rank
190
- self .prepare_item = dataset_optimizer . prepare_item
190
+ self .prepare_item = prepare_item
191
191
self .src_dir = src_dir
192
192
self .remote_src_dir = remote_src_dir
193
193
self .remote_dst_dir = remote_dst_dir
@@ -432,57 +432,21 @@ class WorkerType(Enum):
432
432
PROCESS = "process"
433
433
434
434
435
- class DatasetOptimizer (ABC ):
436
- @abstractmethod
437
- def prepare_dataset_structure (self , src_dir : str , filepaths : List [str ]) -> List [Any ]:
438
- """This function is meant to return a list of item metadata. Each item metadata should be enough to prepare a
439
- single item when called with the prepare_item.
435
+ T = TypeVar ("T" )
440
436
441
- Example::
442
437
443
- # For a classification use case
444
-
445
- def prepare_dataset_structure(self, src_dir, filepaths)
446
- import numpy as np
447
-
448
- filepaths = ['class_a/file_1.ext', ..., 'class_b/file_1.ext', ...]
449
- classes = np.unique([filepath.split("/")[0] for filepath in filepaths])
450
- classes_to_idx_map = {c: idx for idx, c in enumerate(classes)}
451
-
452
- # Return pair with the filepath to the obj and its class
453
- # [('class_a/file_1.ext', 0), ... ('class_b/file_1.ext', 1)]
454
- return [(filepath, classes_to_idx_map[filepath.split("/")[0]]) for filepath in filepaths]
455
-
456
- Example::
457
-
458
- # For a image segmentation use case
459
-
460
- def prepare_dataset_structure(self, src_dir, filepaths)
461
- import numpy as np
462
-
463
- filepaths = ['file_1.JPEG', 'file_1.mask', .... 'file_N.JPEG', 'file_N.mask', ...]
464
-
465
- # [('file_1.JPEG', 'file_1.mask'), ... ('file_N.JPEG', 'file_N.mask')]
466
- return [(x[i], x[i+1]) for i in range(len(filepaths) -1)]
467
-
468
- def prepare_item(self, obj):
469
- image_filepath, mask_filepath = obj
470
-
471
- image = load_and_resize(image_filepath)
472
- mask = load_and_resize(mask_filepath)
473
- return (image, mask)
474
-
475
- """
438
+ @runtime_checkable
439
+ class _OptimizableDataset (Protocol ):
440
+ @staticmethod
441
+ def prepare_dataset_structure (root : str , filepaths : List [str ]) -> List [T ]:
476
442
pass
477
443
478
- def prepare_item (self , metadata_item : Any ) -> Any :
479
- """Using some metadata, prepare the associated item.
444
+ @staticmethod
445
+ def prepare_item (item_metadata : T ) -> Any :
446
+ return item_metadata
480
447
481
- The output of this function will be binarised
482
-
483
- """
484
- return metadata_item
485
448
449
+ class DatasetOptimizer :
486
450
def __init__ (
487
451
self ,
488
452
name : str ,
@@ -547,9 +511,29 @@ def __init__(
547
511
)
548
512
self .random_seed = random_seed
549
513
550
- def run (self ) -> None :
514
+ def run (self , optimizable_dataset : _OptimizableDataset ) -> None :
551
515
"""The `DatasetChunker.run(...)` method is used to trigger the data processing from your dataset into
552
516
chunks."""
517
+ if not isinstance (optimizable_dataset , _OptimizableDataset ):
518
+ raise ValueError (
519
+ dedent (
520
+ """The provided argument to the DatasetOptimizer.run(...) needs to have the following format:
521
+
522
+ Example:
523
+
524
+ class YourDataset:
525
+
526
+ @staticmethod
527
+ def prepare_dataset_structure(root: str, filepaths: List[str]) -> List[T]:
528
+ return [...]
529
+
530
+ @staticmethod
531
+ def prepare_item(item_metadata: T) -> Any:
532
+ return ...
533
+ """
534
+ )
535
+ )
536
+
553
537
t0 = time ()
554
538
print (f"Setup started for `{ self .name } ` with fast_dev_run={ self .fast_dev_run } ." )
555
539
@@ -564,7 +548,7 @@ def run(self) -> None:
564
548
seed_everything (self .random_seed )
565
549
566
550
# Call the setup method of the user
567
- user_items = self .prepare_dataset_structure (self .src_dir , filepaths )
551
+ user_items : List [ Any ] = optimizable_dataset .prepare_dataset_structure (self .src_dir , filepaths )
568
552
569
553
if not isinstance (user_items , list ):
570
554
raise ValueError ("The setup_fn should return a list of item metadata." )
@@ -588,9 +572,9 @@ def run(self) -> None:
588
572
signal .signal (signal .SIGINT , self ._signal_handler )
589
573
590
574
if self .worker_type == WorkerType .THREAD .value :
591
- self ._create_thread_workers (begins , workers_user_items )
575
+ self ._create_thread_workers (optimizable_dataset , begins , workers_user_items )
592
576
else :
593
- self ._create_process_workers (begins , workers_user_items )
577
+ self ._create_process_workers (optimizable_dataset , begins , workers_user_items )
594
578
595
579
print ("Workers are ready ! Starting data processing..." )
596
580
@@ -634,7 +618,9 @@ def _exit_on_error(self, error: str) -> None:
634
618
w .join (0 )
635
619
raise RuntimeError (f"We found the following error { error } ." )
636
620
637
- def _create_thread_workers (self , begins : List [int ], workers_user_items : List [List [Any ]]) -> None :
621
+ def _create_thread_workers (
622
+ self , optimizable_dataset : _OptimizableDataset , begins : List [int ], workers_user_items : List [List [Any ]]
623
+ ) -> None :
638
624
current_total = 0
639
625
total = sum ([len (w ) for w in workers_user_items ])
640
626
with tqdm (total = total , smoothing = 0 ) as pbar :
@@ -649,7 +635,7 @@ def _create_thread_workers(self, begins: List[int], workers_user_items: List[Lis
649
635
begins [worker_idx ],
650
636
self .name ,
651
637
_get_node_rank (),
652
- self ,
638
+ optimizable_dataset . prepare_item ,
653
639
self .src_dir ,
654
640
self .remote_src_dir ,
655
641
self .remote_dst_dir ,
@@ -676,7 +662,9 @@ def _create_thread_workers(self, begins: List[int], workers_user_items: List[Lis
676
662
if current_total == total :
677
663
break
678
664
679
- def _create_process_workers (self , begins : List [int ], workers_user_items : List [List [Any ]]) -> None :
665
+ def _create_process_workers (
666
+ self , optimizable_dataset : _OptimizableDataset , begins : List [int ], workers_user_items : List [List [Any ]]
667
+ ) -> None :
680
668
self .progress_queue = Queue ()
681
669
workers : List [DataWorkerProcess ] = []
682
670
stop_queues : List [Queue ] = []
@@ -688,7 +676,7 @@ def _create_process_workers(self, begins: List[int], workers_user_items: List[Li
688
676
begins [worker_idx ],
689
677
self .name ,
690
678
_get_node_rank (),
691
- self ,
679
+ optimizable_dataset . prepare_item ,
692
680
self .src_dir ,
693
681
self .remote_src_dir ,
694
682
self .remote_dst_dir ,
0 commit comments