
Commit 9bc0bae

Parallelize umap with process pools (#221)
- We use a process pool to do the transform
- We use another process pool to do writing of inference data out (also impacts infer verb)
- Some of the loudest warnings from the umap package have been suppressed
1 parent 32a062b commit 9bc0bae
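
For the warnings change: umap's FutureWarnings are silenced only around the UMAP calls using Python's standard warnings filtering. A minimal standalone sketch of that pattern (not fibad code; noisy_umap_call is a made-up stand-in for the umap calls that emit the warnings):

import warnings

def noisy_umap_call():
    # Hypothetical stand-in for a umap.UMAP fit/transform that emits a FutureWarning.
    warnings.warn("this default will change in a future release", FutureWarning)
    return "embedding"

# The filter is active only inside the block and is restored on exit.
with warnings.catch_warnings():
    warnings.simplefilter(action="ignore", category=FutureWarning)
    result = noisy_umap_call()

print(result)  # warnings raised outside the block are unaffected

The commit applies the same filter again inside the worker function (_transform_batch), presumably because worker processes do not always inherit the parent's warning filters (for example under the "spawn" start method).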

File tree

2 files changed: +94 -21 lines changed


src/fibad/data_sets/inference_dataset.py

Lines changed: 33 additions & 6 deletions
@@ -1,5 +1,6 @@
 import logging
 from collections.abc import Generator
+from multiprocessing import Pool
 from pathlib import Path
 from typing import Optional, Union
 
@@ -168,6 +169,7 @@ def __init__(self, result_dir: Union[str, Path]):
 
         self.all_ids = np.array([], dtype=np.int64)
         self.all_batch_nums = np.array([], dtype=np.int64)
+        self.writer_pool = Pool()
 
     def write_batch(self, ids: np.ndarray, tensors: list[np.ndarray]):
         """Write a batch of tensors into the dataset. This writes the whole batch immediately.
@@ -197,22 +199,47 @@ def write_batch(self, ids: np.ndarray, tensors: list[np.ndarray]):
         if savepath.exists():
             RuntimeError(f"Writing objects in batch {self.batch_index} but {filename} already exists.")
 
-        np.save(savepath, structured_batch, allow_pickle=False)
+        self.writer_pool.apply_async(
+            func=np.save, args=(savepath, structured_batch), kwds={"allow_pickle": False}
+        )
+
         self.all_ids = np.append(self.all_ids, ids)
         self.all_batch_nums = np.append(self.all_batch_nums, np.full(batch_len, self.batch_index))
 
         self.batch_index += 1
 
     def write_index(self):
-        """Writes out the batch index built up by this object over multiple write_batch calls."""
+        """Writes out the batch index built up by this object over multiple write_batch calls.
+        See save_batch_index for details.
+        """
+        # First ensure we are done writing out all batches
+        self.writer_pool.close()
+        self.writer_pool.join()
+
+        # Then write out the batch index.
+        InferenceDataSetWriter.save_batch_index(self.result_dir, self.all_ids, self.all_batch_nums)
+
+    @staticmethod
+    def save_batch_index(result_dir: Path, all_ids: np.ndarray, all_batch_nums: np.ndarray):
+        """Save a batch index in the result directory provided
+
+        Parameters
+        ----------
+        result_dir : Path
+            The results directory
+        all_ids : np.ndarray
+            All IDs to write out.
+        all_batch_nums : np.ndarray
+            The corresponding batch numbers for the IDs provided.
+        """
         batch_index_dtype = np.dtype([("id", np.int64), ("batch_num", np.int64)])
-        batch_index = np.zeros(len(self.all_ids), batch_index_dtype)
-        batch_index["id"] = np.array(self.all_ids)
-        batch_index["batch_num"] = np.array(self.all_batch_nums)
+        batch_index = np.zeros(len(all_ids), batch_index_dtype)
+        batch_index["id"] = np.array(all_ids)
+        batch_index["batch_num"] = np.array(all_batch_nums)
         batch_index.sort(order="id")
 
         filename = "batch_index.npy"
-        savepath = self.result_dir / filename
+        savepath = result_dir / filename
         if savepath.exists():
             RuntimeError("The path to save batch index already exists.")
         np.save(savepath, batch_index, allow_pickle=False)
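
The effect of the change above is that each batch's np.save becomes a fire-and-forget task on a writer pool, and write_index() closes and joins that pool so every batch file is on disk before the index itself is written. A minimal standalone sketch of the same pattern (not fibad code; AsyncArrayWriter and its methods are illustrative names):

from multiprocessing import Pool
from pathlib import Path

import numpy as np

class AsyncArrayWriter:
    """Queues np.save calls on a process pool and flushes them on finish()."""

    def __init__(self, out_dir: Path):
        out_dir.mkdir(parents=True, exist_ok=True)
        self.out_dir = out_dir
        self.pool = Pool()  # defaults to cpu_count() worker processes
        self.batch_index = 0

    def write_batch(self, batch: np.ndarray):
        savepath = self.out_dir / f"batch_{self.batch_index}.npy"
        # apply_async returns immediately; the save runs in a worker process.
        self.pool.apply_async(func=np.save, args=(savepath, batch), kwds={"allow_pickle": False})
        self.batch_index += 1

    def finish(self):
        # close() stops accepting new work; join() blocks until queued saves are done,
        # so anything written afterwards (such as an index file) sees all batches on disk.
        self.pool.close()
        self.pool.join()

if __name__ == "__main__":
    writer = AsyncArrayWriter(Path("batches_out"))  # hypothetical output directory
    for _ in range(3):
        writer.write_batch(np.random.rand(8, 4))
    writer.finish()  # all three .npy files exist once this returns

One caveat of this style of apply_async usage: an exception inside np.save is only surfaced if the returned AsyncResult is checked or an error_callback is supplied, so a failed write can pass silently.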

src/fibad/verbs/umap.py

Lines changed: 61 additions & 15 deletions
@@ -1,6 +1,8 @@
 import logging
 import pickle
+import warnings
 from argparse import ArgumentParser, Namespace
+from multiprocessing import cpu_count
 from pathlib import Path
 from typing import Optional, Union
 
@@ -60,13 +62,21 @@ def run(self, input_dir: Optional[Union[Path, str]] = None):
         None
             The method does not return anything but saves the UMAP representations to disk.
         """
+        with warnings.catch_warnings():
+            warnings.simplefilter(action="ignore", category=FutureWarning)
+            return self._run(input_dir)
+
+    def _run(self, input_dir: Optional[Union[Path, str]] = None):
+        """See run()"""
+        from multiprocessing import Pool
+
         import umap
         from tqdm.auto import tqdm
 
         from fibad.config_utils import create_results_dir
         from fibad.data_sets.inference_dataset import InferenceDataSet, InferenceDataSetWriter
 
-        reducer = umap.UMAP(**self.config["umap.UMAP"])
+        self.reducer = umap.UMAP(**self.config["umap.UMAP"])
 
         # Set up the results directory where we will store our umapped output
         results_dir = create_results_dir(self.config, "umap")
@@ -87,29 +97,65 @@ def run(self, input_dir: Optional[Union[Path, str]] = None):
         data_sample = inference_results[index_choices].numpy().reshape((sample_size, -1))
 
         # Fit a single reducer on the sampled data
-        reducer.fit(data_sample)
+        self.reducer.fit(data_sample)
 
         # Save the reducer to our results directory
         with open(results_dir / "umap.pickle", "wb") as f:
-            pickle.dump(reducer, f)
+            pickle.dump(self.reducer, f)
 
         # Run all data through the reducer in batches, writing it out as we go.
         batch_size = self.config["data_loader"]["batch_size"]
         num_batches = int(np.ceil(total_length / batch_size))
 
         all_indexes = np.arange(0, total_length)
         all_ids = np.array([int(i) for i in inference_results.ids()])
-        for batch_indexes in tqdm(
-            np.array_split(all_indexes, num_batches),
-            desc="Creating Lower Dimensional Representation using UMAP",
-            total=num_batches,
-        ):
-            # We flatten all dimensions of the input array except the dimension
-            # corresponding to batch elements. This ensures that all inputs to
-            # the UMAP algorithm are flattend per input item in the batch
-            batch = inference_results[batch_indexes].reshape(len(batch_indexes), -1)
-            batch_ids = all_ids[batch_indexes]
-            transformed_batch = reducer.transform(batch)
-            umap_results.write_batch(batch_ids, transformed_batch)
+
+        # Process pool to do all the transforms
+        with Pool(processes=cpu_count()) as pool:
+            # Generator expression that gives a batch tuple composed of:
+            # batch ids, inference results
+            args = (
+                (
+                    all_ids[batch_indexes],
+                    # We flatten all dimensions of the input array except the dimension
+                    # corresponding to batch elements. This ensures that all inputs to
+                    # the UMAP algorithm are flattend per input item in the batch
+                    inference_results[batch_indexes].reshape(len(batch_indexes), -1),
+                )
+                for batch_indexes in np.array_split(all_indexes, num_batches)
+            )
+
+            # iterate over the mapped results to write out the umapped points
+            # imap returns results as they complete so writing should complete in parallel for large datasets
+            for batch_ids, transformed_batch in tqdm(
+                pool.imap(self._transform_batch, args),
+                desc="Creating LowerDimensional Representation using UMAP:",
+                total=num_batches,
+            ):
+                logger.debug("Writing a batch out async...")
+                umap_results.write_batch(batch_ids, transformed_batch)
 
         umap_results.write_index()
+
+    def _transform_batch(self, batch_tuple: tuple):
+        """Private helper to transform a single batch
+
+        Parameters
+        ----------
+        batch_tuple : tuple()
+            first element is the IDs of the batch as a numpy array
+            second element is the inference results to transform as a numpy array with shape (batch_len, N)
+            where N is the total number of dimensions in the inference result. Caller flattens all inference
+            result axes for us.
+
+        Returns
+        -------
+        tuple
+            first element is the ids of the batch as a numpy array
+            second element is the results of running the umap transform on the input as a numpy array.
        """
+        batch_ids, batch = batch_tuple
+        with warnings.catch_warnings():
+            warnings.simplefilter(action="ignore", category=FutureWarning)
+            logger.debug("Transforming a batch ...")
+            return (batch_ids, self.reducer.transform(batch))
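
The loop above replaces the serial transform-then-write loop with pool.imap over a generator of (ids, batch) tuples, so batches are transformed in worker processes while the main process writes finished results. A small self-contained sketch of that pattern (not fibad code; transform_batch is a trivial stand-in for the UMAP transform):

from multiprocessing import Pool, cpu_count

import numpy as np

def transform_batch(batch_tuple):
    # Stand-in for the real per-batch transform; here it is just a row mean.
    batch_ids, batch = batch_tuple
    return batch_ids, batch.mean(axis=1, keepdims=True)

if __name__ == "__main__":
    data = np.random.rand(1000, 16)
    ids = np.arange(1000)
    num_batches = 10

    # Generator of (ids, rows) tuples; each tuple is pickled and shipped to a worker.
    args = (
        (ids[batch_indexes], data[batch_indexes])
        for batch_indexes in np.array_split(np.arange(len(data)), num_batches)
    )

    with Pool(processes=cpu_count()) as pool:
        # imap yields results in input order as each one completes, so the
        # main process can write a finished batch while others are still running.
        for batch_ids, transformed in pool.imap(transform_batch, args):
            print(int(batch_ids[0]), transformed.shape)

Compared with pool.map, imap lets writing start as soon as the first batch is transformed instead of waiting for the whole dataset, which is what allows the writer pool from inference_dataset.py to overlap with the remaining transforms.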
