11import logging
22import pickle
3+ import warnings
34from argparse import ArgumentParser , Namespace
5+ from multiprocessing import cpu_count
46from pathlib import Path
57from typing import Optional , Union
68
@@ -60,13 +62,21 @@ def run(self, input_dir: Optional[Union[Path, str]] = None):
6062 None
6163 The method does not return anything but saves the UMAP representations to disk.
6264 """
65+ with warnings .catch_warnings ():
66+ warnings .simplefilter (action = "ignore" , category = FutureWarning )
67+ return self ._run (input_dir )
68+
69+ def _run (self , input_dir : Optional [Union [Path , str ]] = None ):
70+ """See run()"""
71+ from multiprocessing import Pool
72+
6373 import umap
6474 from tqdm .auto import tqdm
6575
6676 from fibad .config_utils import create_results_dir
6777 from fibad .data_sets .inference_dataset import InferenceDataSet , InferenceDataSetWriter
6878
69- reducer = umap .UMAP (** self .config ["umap.UMAP" ])
79+ self . reducer = umap .UMAP (** self .config ["umap.UMAP" ])
7080
7181 # Set up the results directory where we will store our umapped output
7282 results_dir = create_results_dir (self .config , "umap" )
@@ -87,29 +97,65 @@ def run(self, input_dir: Optional[Union[Path, str]] = None):
8797 data_sample = inference_results [index_choices ].numpy ().reshape ((sample_size , - 1 ))
8898
8999 # Fit a single reducer on the sampled data
90- reducer .fit (data_sample )
100+ self . reducer .fit (data_sample )
91101
92102 # Save the reducer to our results directory
93103 with open (results_dir / "umap.pickle" , "wb" ) as f :
94- pickle .dump (reducer , f )
104+ pickle .dump (self . reducer , f )
95105
96106 # Run all data through the reducer in batches, writing it out as we go.
97107 batch_size = self .config ["data_loader" ]["batch_size" ]
98108 num_batches = int (np .ceil (total_length / batch_size ))
99109
100110 all_indexes = np .arange (0 , total_length )
101111 all_ids = np .array ([int (i ) for i in inference_results .ids ()])
102- for batch_indexes in tqdm (
103- np .array_split (all_indexes , num_batches ),
104- desc = "Creating Lower Dimensional Representation using UMAP" ,
105- total = num_batches ,
106- ):
107- # We flatten all dimensions of the input array except the dimension
108- # corresponding to batch elements. This ensures that all inputs to
109- # the UMAP algorithm are flattend per input item in the batch
110- batch = inference_results [batch_indexes ].reshape (len (batch_indexes ), - 1 )
111- batch_ids = all_ids [batch_indexes ]
112- transformed_batch = reducer .transform (batch )
113- umap_results .write_batch (batch_ids , transformed_batch )
112+
113+ # Process pool to do all the transforms
114+ with Pool (processes = cpu_count ()) as pool :
115+ # Generator expression that gives a batch tuple composed of:
116+ # batch ids, inference results
117+ args = (
118+ (
119+ all_ids [batch_indexes ],
120+ # We flatten all dimensions of the input array except the dimension
121+ # corresponding to batch elements. This ensures that all inputs to
122+ # the UMAP algorithm are flattened per input item in the batch
123+ inference_results [batch_indexes ].reshape (len (batch_indexes ), - 1 ),
124+ )
125+ for batch_indexes in np .array_split (all_indexes , num_batches )
126+ )
127+
128+ # iterate over the mapped results to write out the umapped points
129+ # imap yields results lazily and in submission order, so writing a finished batch overlaps with transforms still running in the pool
130+ for batch_ids , transformed_batch in tqdm (
131+ pool .imap (self ._transform_batch , args ),
132+ desc = "Creating LowerDimensional Representation using UMAP:" ,
133+ total = num_batches ,
134+ ):
135+ logger .debug ("Writing a batch out async..." )
136+ umap_results .write_batch (batch_ids , transformed_batch )
114137
115138 umap_results .write_index ()
139+
def _transform_batch(self, batch_tuple: tuple):
    """Run a single batch through ``self.reducer`` and pair the output with its IDs.

    The argument is a single tuple rather than two parameters so that this
    method can be handed directly to a one-argument mapping call.

    Parameters
    ----------
    batch_tuple : tuple
        Exactly two elements, ``(ids, data)``:
        the first is the IDs belonging to the batch, passed through untouched;
        the second is the data handed to ``self.reducer.transform``. The caller
        is expected to have already flattened every non-batch axis, i.e. the
        array is presumably shaped (batch_len, N) -- verify against the caller.

    Returns
    -------
    tuple
        ``(ids, transformed)`` where ``ids`` is the first element of the input
        and ``transformed`` is whatever ``self.reducer.transform`` returned for
        the second element.
    """
    ids, flattened = batch_tuple

    # FutureWarning is silenced locally for the duration of the transform.
    # NOTE(review): presumably repeated here because this may execute in a
    # worker process where a filter installed by the caller is not active.
    with warnings.catch_warnings():
        warnings.simplefilter(action="ignore", category=FutureWarning)
        logger.debug("Transforming a batch ...")
        transformed = self.reducer.transform(flattened)

    return (ids, transformed)
0 commit comments