@@ -241,7 +241,10 @@ def _map_items_to_workers_sequentially(num_workers: int, user_items: List[Any])
 
 
 def _map_items_to_workers_weighted(
-    num_workers: int, user_items: List[Any], weights: Optional[List[int]] = None
+    num_workers: int,
+    user_items: List[Any],
+    weights: Optional[List[int]] = None,
+    file_size: bool = True,
 ) -> List[List[Any]]:
     # Associate the items to the workers based on number of nodes and node rank.
     weights = [1] * len(user_items) if weights is None else weights
@@ -255,7 +258,11 @@ def _map_items_to_workers_weighted(
     for worker_id, size in worker_weights.items():
         if worker_id not in worker_ids_this_node:
             continue
-        print(f"Worker {worker_id} gets {size / 1e6:.1f} MB ({len(worker_items[worker_id])} files)")
+
+        if file_size:
+            print(f"Worker {worker_id} gets {size / 1e6:.1f} MB ({len(worker_items[worker_id])} files)")
+        else:
+            print(f"Worker {worker_id} gets {len(worker_items[worker_id])} items for a total weight of {size}.")
 
     return [worker_items[worker_id] for worker_id in worker_ids_this_node]
 
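Aside for reviewers: the weighted split is essentially greedy load balancing, i.e. each item carries a weight and should land on the worker with the smallest running total. A minimal standalone sketch of that idea (an illustration only, not the exact implementation behind this diff; greedy_weighted_split is a hypothetical name):

import heapq
from typing import Any, List

def greedy_weighted_split(num_workers: int, user_items: List[Any], weights: List[int]) -> List[List[Any]]:
    # Min-heap of (accumulated weight, worker_id): always hand the next item to the lightest worker.
    heap = [(0, worker_id) for worker_id in range(num_workers)]
    heapq.heapify(heap)
    assignments: List[List[Any]] = [[] for _ in range(num_workers)]
    # Place heavy items first so the running totals stay close together.
    for weight, item in sorted(zip(weights, user_items), key=lambda pair: -pair[0]):
        total, worker_id = heapq.heappop(heap)
        assignments[worker_id].append(item)
        heapq.heappush(heap, (total + weight, worker_id))
    return assignments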
@@ -769,6 +776,7 @@ def __init__(
         fast_dev_run: Optional[Union[bool, int]] = None,
         random_seed: Optional[int] = 42,
         reorder_files: bool = True,
+        weights: Optional[List[int]] = None,
     ):
         """The `DatasetOptimiser` provides an efficient way to process data across multiple machines into chunks to make
         training faster.
@@ -784,6 +792,8 @@ def __init__(
             random_seed: The random seed to be set before shuffling the data.
             reorder_files: By default, reorders the files by file size to distribute work equally among all workers.
                 Set this to ``False`` if the order in which samples are processed should be preserved.
+            weights: Provide a list of weights associated with the inputs.
+                This is used to evenly split the work among the workers.
 
         """
         self.input_dir = _resolve_dir(input_dir)
@@ -799,6 +809,7 @@ def __init__(
         self.error_queue: Queue = Queue()
         self.stop_queues: List[Queue] = []
         self.reorder_files = reorder_files
+        self.weights = weights
 
         # Ensure the input dir is the same across all nodes
         self.input_dir = broadcast_object("input_dir", self.input_dir)
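With the weights stored on the instance, callers opt in at construction time. A hypothetical call, assuming `input_dir` and `num_workers` are constructor arguments as the surrounding diff suggests (any other required arguments are omitted here):

# Hypothetical usage: one weight per item later returned by `prepare_structure`.
optimiser = DatasetOptimiser(
    input_dir="/data/raw",
    num_workers=4,
    weights=[10, 1, 1, 5],
)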
@@ -827,7 +838,14 @@ def run(self, data_recipe: DataRecipe) -> None:
         if not isinstance(user_items, list):
             raise ValueError("The `prepare_structure` should return a list of item metadata.")
 
-        if self.reorder_files and self.input_dir.path:
+        if self.weights is not None:
+            if len(self.weights) != len(user_items):
+                raise ValueError("The provided weights must have the same length as the inputs.")
+            workers_user_items = _map_items_to_workers_weighted(
+                num_workers=self.num_workers, user_items=user_items, weights=self.weights, file_size=False
+            )
+
+        elif self.reorder_files and self.input_dir.path:
             # TODO: Only do this on node 0, and broadcast the item sizes to the other nodes.
             item_sizes = _get_item_filesizes(user_items, base_path=self.input_dir.path)
             workers_user_items = _map_items_to_workers_weighted(
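Putting the two hunks together: explicit weights short-circuit the file-size scan, and `file_size=False` switches the per-worker log from megabytes to raw weight units. A quick check of the balancing behaviour, reusing the greedy_weighted_split sketch from above:

items = [f"item_{i}" for i in range(6)]
weights = [9, 1, 1, 1, 4, 4]
for worker_id, assigned in enumerate(greedy_weighted_split(2, items, weights)):
    total = sum(weights[items.index(item)] for item in assigned)
    print(f"Worker {worker_id} gets {len(assigned)} items for a total weight of {total}.")
# Both workers end up with a total weight of 10.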