diff --git a/python/coverage.svg b/python/coverage.svg index 59d64b3..607d3de 100644 --- a/python/coverage.svg +++ b/python/coverage.svg @@ -15,7 +15,7 @@ coverage coverage - 93% - 93% + 91% + 91% diff --git a/python/ouroboros/helpers/mem.py b/python/ouroboros/helpers/mem.py index 62b0944..73a5023 100644 --- a/python/ouroboros/helpers/mem.py +++ b/python/ouroboros/helpers/mem.py @@ -1,7 +1,7 @@ from dataclasses import astuple, replace, asdict, fields from functools import cached_property from multiprocessing.shared_memory import SharedMemory, ShareableList -from multiprocessing.managers import SharedMemoryManager, BaseManager +from multiprocessing.managers import SharedMemoryManager, BaseManager, ListProxy from sys import stdout from time import sleep from typing import TextIO @@ -20,42 +20,6 @@ def is_advanced_index(index): return isinstance(index, (list, np.ndarray)) and not isinstance(index, slice) -class SharedNPManager(SharedMemoryManager): - """ Manages shared memory numpy arrays. """ - def __init__(self, *args, - queue_mem: list[tuple[DataShape, np.dtype] | tuple[tuple[DataShape, np.dtype]]] = [], - **kwargs): - SharedMemoryManager.__init__(self, *args, **kwargs) - - self.__mem_queue = [] - for mem in queue_mem: - if isinstance(mem[0], tuple): - self.__mem_queue.append([mem[0][0], mem[0][1]] + list(mem[1:])) - else: - self.__mem_queue.append([mem[0], mem[1]]) - - def SharedNPArray(self, shape: DataShape, dtype: np.dtype, *create_with: tuple[DataShape, np.dtype]): - full_set = [(shape, dtype)] + list(create_with) - size = max([np.prod(astuple(shape), dtype=object) * np.dtype(dtype).itemsize for (shape, dtype) in full_set]) - mem = self.SharedMemory(int(size)) - result = [SharedNPArray(mem.name, shape, dtype) for (shape, dtype) in full_set] - return result[0] if len(result) == 1 else result - - def clear_queue(self): - ar_mem = [] - while len(self.__mem_queue) > 0: - new_mem = self.SharedNPArray(*self.__mem_queue.pop(0)) - ar_mem += new_mem if isinstance(new_mem, list) else [new_mem] - return ar_mem - - def __enter__(self): - this = [BaseManager.__enter__(self)] - return tuple(this + self.clear_queue()) - - def __exit__(self, *args, **kwargs): - super().__exit__(*args, **kwargs) - - class SharedNPArray: def __init__(self, name: str, shape: DataShape, dtype: np.dtype, views: list = None, *, allocate: bool = False): @@ -148,21 +112,95 @@ def shape(self): return shape -def cleanup_mem(*shm_objects): - """ Close and unlink shared memory objects. +_termed_mem = [] - :param shm_objects: Shared memory objects to shut down. - """ - for shm in shm_objects: - if isinstance(shm, SharedNPArray): - shm.shutdown() - del shm - elif isinstance(shm, ShareableList): - shm.shm.close() - shm.shm.unlink() - elif isinstance(shm, SharedMemory): - shm.close() - shm.unlink() + +def get_termed_mem(): + """Returns the existing global list rather than creating a new one.""" + return _termed_mem + + +class SharedNPManager(SharedMemoryManager): + SharedMemoryManager.register( + '_TermedMem', + callable=get_termed_mem, + proxytype=ListProxy + ) + + """ Manages shared memory numpy arrays. 
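+
+    Registers a shared '_TermedMem' list (a ListProxy over the module-level
+    _termed_mem list) so that SharedMemory blocks created via TermedNPArray
+    are tracked across processes and can be unlinked later through
+    remove_termed() or shutdown().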
""" + def __init__(self, *args, + queue_mem: list[tuple[DataShape, np.dtype] | tuple[tuple[DataShape, np.dtype]]] = [], + **kwargs): + SharedMemoryManager.__init__(self, *args, **kwargs) + + self.__mem_queue = [] + for mem in queue_mem: + if isinstance(mem[0], tuple): + self.__mem_queue.append([mem[0][0], mem[0][1]] + list(mem[1:])) + else: + self.__mem_queue.append([mem[0], mem[1]]) + self.__termed_mem = None + + def SharedNPArray(self, shape: DataShape, dtype: np.dtype, *create_with: tuple[DataShape, np.dtype]): + full_set = [(shape, dtype)] + list(create_with) + size = max([np.prod(astuple(shape), dtype=object) * np.dtype(dtype).itemsize for (shape, dtype) in full_set]) + mem = self.SharedMemory(int(size)) + result = [SharedNPArray(mem.name, shape, dtype) for (shape, dtype) in full_set] + return result[0] if len(result) == 1 else result + + def TermedNPArray(self, shape: DataShape, dtype: np.dtype, *create_with: tuple[DataShape, np.dtype]): + full_set = [(shape, dtype)] + list(create_with) + size = max([np.prod(astuple(shape), dtype=object) * np.dtype(dtype).itemsize for (shape, dtype) in full_set]) + mem = SharedMemory(create=True, size=int(size)) + result = [SharedNPArray(mem.name, shape, dtype) for (shape, dtype) in full_set] + self.__termed_mem.append(mem.name) + return result[0] if len(result) == 1 else result + + def clear_queue(self): + ar_mem = [] + while len(self.__mem_queue) > 0: + new_mem = self.SharedNPArray(*self.__mem_queue.pop(0)) + ar_mem += new_mem if isinstance(new_mem, list) else [new_mem] + return ar_mem + + def remove_termed(self, mem): + if isinstance(mem, SharedNPArray): + name = mem.name + mem.shutdown() + else: + name = mem + if name in self.__termed_mem: + self.__termed_mem.pop(self.__termed_mem.index(name)) + t = SharedMemory(name) + t.close() + t.unlink() + else: + raise FileNotFoundError(f"{name} is not a termed shared memory array. {self.__termed_mem}") + + def shutdown(self): + for name in self.__termed_mem: + t = SharedMemory(name) + t.close() + t.unlink() + super().shutdown() + + def start(self, *args, **kwargs): + super().start(*args, **kwargs) + # Initialize the proxy immediately upon start + self.__termed_mem = self._TermedMem() + + def connect(self): + super().connect() + # Initialize the proxy immediately upon connect + self.__termed_mem = self._TermedMem() + + def __enter__(self): + this = [BaseManager.__enter__(self)] + return tuple(this + self.clear_queue()) + + def __exit__(self, *args, **kwargs): + print("Exiting! SHM!") + super().__exit__(*args, **kwargs) def exit_cleanly(step: str, *shm_objects, return_code: int = 0, statement: str = '', log_level: LOG = LOG.TIME, @@ -191,3 +229,20 @@ def mem_monitor(mem_file, mem_store, pid): last_step = last_step_arr.tobytes().decode() log.write(last_step, out=out, pid=pid) sleep(MEM_INTERVAL_TIMER) + + +def cleanup_mem(*shm_objects): + """ Close and unlink shared memory objects. + + :param shm_objects: Shared memory objects to shut down. 
+ """ + for shm in shm_objects: + if isinstance(shm, SharedNPArray): + shm.shutdown() + del shm + elif isinstance(shm, ShareableList): + shm.shm.close() + shm.shm.unlink() + elif isinstance(shm, SharedMemory): + shm.close() + shm.unlink() diff --git a/python/ouroboros/helpers/volume_cache.py b/python/ouroboros/helpers/volume_cache.py index f9a73d0..c04e6e4 100644 --- a/python/ouroboros/helpers/volume_cache.py +++ b/python/ouroboros/helpers/volume_cache.py @@ -1,7 +1,6 @@ from dataclasses import astuple import os -import sys -import traceback +import time from cloudvolume import CloudVolume, VolumeCutout, Bbox import numpy as np @@ -37,10 +36,7 @@ def __init__( self.volumes = [None] * len(bounding_boxes) self.use_shared = use_shared - - if self.use_shared: - self.shm_host = SharedNPManager() - self.shm_host.__enter__() + self.__shm_host = None # Indicates whether the a volume should be cached after the last slice to request it is processed self.cache_volume = [False] * len(bounding_boxes) @@ -59,6 +55,11 @@ def to_dict(self) -> dict: "flush_cache": self.flush_cache, } + def connect_shm(self, address: str, authkey: str): + self.__shm_host = SharedNPManager(address=address, authkey=authkey) + self.__shm_host.connect() + self.__authkey = authkey + @staticmethod def from_dict(data: dict) -> "VolumeCache": bounding_boxes = [BoundingBox.from_dict(bb) for bb in data["bounding_boxes"]] @@ -131,7 +132,7 @@ def request_volume_for_slice(self, slice_index: int): # Download the volume if it is not already cached if self.volumes[vol_index] is None: - self.download_volume(vol_index, bounding_box) + self.volumes[vol_index] = download_volume(self.cv, bounding_box, mip=self.mip) # Remove the last requested volume if it is not to be cached if ( @@ -153,68 +154,9 @@ def remove_volume(self, volume_index: int, destroy_shared: bool = False): if not self.use_shared: self.volumes[volume_index] = None elif destroy_shared: - self.volumes[volume_index].shutdown() + self.__shm_host.remove_termed(self.volumes[volume_index]) self.volumes[volume_index] = None - def download_volume( - self, volume_index: int, bounding_box: BoundingBox, parallel=False - ) -> VolumeCutout: - bbox = bounding_box.to_cloudvolume_bbox().astype(int) - vol_shape = NGOrder(*bbox.size3(), self.cv.cv.num_channels) - - # Limit size of area we are grabbing, in case we go out of bounds. - dl_box = Bbox.intersection(self.cv.cv.bounds, bbox) - local_min = [int(start) for start in np.subtract(dl_box.minpt, bbox.minpt)] - - local_bounds = np.s_[*[slice(start, stop) for start, stop in - zip(local_min, np.sum([local_min, dl_box.size3()], axis=0))], - :] - - # Download the bounding box volume - if self.use_shared: - volume = self.shm_host.SharedNPArray(vol_shape, np.float32) - with volume as volume_data: - volume_data[:] = 0 # Prob not most efficient but makes math much easier - volume_data[local_bounds] = self.cv.cv.download(dl_box, mip=self.mip, parallel=parallel) - else: - volume = np.zeros(astuple(vol_shape)) - volume[local_bounds] = self.cv.cv.download(dl_box, mip=self.mip, parallel=parallel) - - # Store the volume in the cache - self.volumes[volume_index] = volume - - def create_processing_data(self, volume_index: int, parallel=False): - """ - Generate a data packet for processing a volume. - - Suitable for parallel processing. - - Parameters: - ---------- - volume_index (int): The index of the volume to process. - parallel (bool): Whether to download the volume in parallel (only do parallel if downloading in one thread). 
- - Returns: - ------- - tuple: A tuple containing the volume data, the bounding box of the volume, - the slice indices associated with the volume, and a function to remove the volume from the cache. - """ - - bounding_box = self.bounding_boxes[volume_index] - - # Download the volume if it is not already cached - if self.volumes[volume_index] is None: - try: - self.download_volume(volume_index, bounding_box, parallel=parallel) - except BaseException as be: - traceback.print_tb(be.__traceback__, file=sys.stderr) - return f"Error downloading data: {be}" - - # Get all slice indices associated with this volume - slice_indices = self.get_slice_indices(volume_index) - - return self.volumes[volume_index], bounding_box, slice_indices, volume_index - def get_slice_indices(self, volume_index: int): return [i for i, v in enumerate(self.link_rects) if v == volume_index] @@ -264,6 +206,38 @@ def flush_cache(self): self.cv.cache.flush() +def download_volume( + cv: CloudVolumeInterface, bounding_box: BoundingBox, mip, parallel=False, + use_shared=False, shm_address: str = None, shm_authkey: str = None, **kwargs +) -> VolumeCutout: + start = time.perf_counter() + bbox = bounding_box.to_cloudvolume_bbox().astype(int) + vol_shape = NGOrder(*bbox.size3(), cv.cv.num_channels) + + # Limit size of area we are grabbing, in case we go out of bounds. + dl_box = Bbox.intersection(cv.cv.bounds, bbox) + local_min = [int(start) for start in np.subtract(dl_box.minpt, bbox.minpt)] + + local_bounds = np.s_[*[slice(start, stop) for start, stop in + zip(local_min, np.sum([local_min, dl_box.size3()], axis=0))], + :] + + # Download the bounding box volume + if use_shared: + shm_host = SharedNPManager(address=shm_address, authkey=shm_authkey) + shm_host.connect() + volume = shm_host.TermedNPArray(vol_shape, np.float32) + with volume as volume_data: + volume_data[:] = 0 # Prob not most efficient but makes math much easier + volume_data[local_bounds] = cv.cv.download(dl_box, mip=mip, parallel=parallel) + else: + volume = np.zeros(astuple(vol_shape)) + volume[local_bounds] = cv.cv.download(dl_box, mip=mip, parallel=parallel) + + # Return volume + return volume, bounding_box, time.perf_counter() - start, *kwargs.values() + + def get_mip_volume_sizes(source_url: str) -> dict: """ Get the volume sizes for all available MIPs. diff --git a/python/ouroboros/pipeline/slice_parallel_pipeline.py b/python/ouroboros/pipeline/slice_parallel_pipeline.py index c395aca..bfdf41c 100644 --- a/python/ouroboros/pipeline/slice_parallel_pipeline.py +++ b/python/ouroboros/pipeline/slice_parallel_pipeline.py @@ -1,5 +1,7 @@ from functools import partial from pathlib import Path +import psutil +import secrets import sys import threading import traceback @@ -8,7 +10,8 @@ coordinate_grid, slice_volume_from_grids ) -from ouroboros.helpers.volume_cache import VolumeCache +from ouroboros.helpers.mem import SharedNPManager +from ouroboros.helpers.volume_cache import VolumeCache, download_volume from ouroboros.helpers.files import ( format_slice_output_file, format_slice_output_multiple, @@ -120,19 +123,21 @@ def _process(self, input_data: tuple[any]) -> None | str: all_work_done = threading.Event() # Minimum and maximum boundaries. -# bound_shm = SharedNPArray("boundaries", X(2), np.float64, allocate=True) -# with bound_shm[:] as boundaries: boundaries = np.zeros(2, dtype=np.float32) + # Set an SharedMemoryManager key so we can pass it around later. 
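+        # Worker processes reconnect to this manager by address with the same
+        # authkey (see download_volume), so termed shared memory allocated in
+        # the download pool is tracked and cleaned up by this manager.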
+ authkey = secrets.token_bytes(32) + # Start the download volumes process and process downloaded volumes as they become available in the queue try: - with concurrent.futures.ThreadPoolExecutor( - max_workers=max(self.num_threads, 4) + with concurrent.futures.ProcessPoolExecutor( + max_workers=self.num_processes // 4 ) as download_executor, concurrent.futures.ProcessPoolExecutor( - max_workers=self.num_processes - ) as process_executor: + max_workers=self.num_processes * 3 // 4 + ) as process_executor, SharedNPManager(authkey=authkey) as (shm_host, ): download_futures = [] process_futures = [] + volume_cache.connect_shm(shm_host.address, authkey) vol_range = list(reversed(range(len(volume_cache.volumes)))) @@ -146,46 +151,69 @@ def _process(self, input_data: tuple[any]) -> None | str: temporary_path=temp_file_path, shared=volume_cache.use_shared) - partial_dl_executor = partial(dl_worker, - volume_cache=volume_cache, - parallel_fetch=(self.num_threads == 1)) + partial_dl_executor = partial(download_volume, + cv=volume_cache.cv, + mip=volume_cache.mip, + parallel=False, + use_shared=volume_cache.use_shared, + shm_address=shm_host.address, + shm_authkey=authkey) def dl_completed(future): + volume, bounding_box, download_time, index = future.result() + self.add_timing("Download Time", download_time) process_futures.append(process_executor.submit(partial_slice_executor, - processing_data=future.result())) + volume=volume, + bounding_box=bounding_box, + slice_indices=volume_cache.get_slice_indices(index), + volume_index=index + )) process_futures[-1].add_done_callback(processor_completed) + self.update_progress( + np.sum([future.done() for future in download_futures]) / (len(volume_cache.volumes) * 4) + + np.sum([future.done() for future in process_futures]) / (len(volume_cache.volumes) * 4 // 3) + ) + if volume_cache.use_shared or volume_cache.cache_volume: + volume_cache.volumes[index] = volume def processor_completed(future): volume_index, durations, min_val, max_val = future.result() -# with bound_shm[:] as boundaries: boundaries[0] = min(boundaries[0], min_val) boundaries[1] = max(boundaries[1], max_val) if volume_cache.use_shared: volume_cache.remove_volume(volume_index, destroy_shared=True) + print(f"Removed; Volume Count: {sum(vol is not None for vol in volume_cache.volumes)}" + f" | {psutil.virtual_memory().available}") for key, value in durations.items(): self.add_timing_list(key, value) - # Update the progress bar - # 1/3 DL, 2/3 Process self.update_progress( - np.sum([future.done() for future in download_futures]) / (len(volume_cache.volumes) * 3) + - np.sum([future.done() for future in process_futures]) / (len(volume_cache.volumes) * 3 / 2) + np.sum([future.done() for future in download_futures]) / (len(volume_cache.volumes) * 4) + + np.sum([future.done() for future in process_futures]) / (len(volume_cache.volumes) * 4 // 3) ) if len(vol_range) > 0: - download_futures.append(download_executor.submit(partial_dl_executor, volume=vol_range.pop())) + index = vol_range.pop() + download_futures.append( + download_executor.submit(partial_dl_executor, + bounding_box=volume_cache.bounding_boxes[index], + volume_index=index)) download_futures[-1].add_done_callback(dl_completed) if self.progress >= 1.0: all_work_done.set() # Download all volumes in parallel, and add the callback to process them as they finish. 
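+        # Seed the download pool with one job per processing worker plus one
+        # (num_processes * 3 // 4 + 1); further downloads are submitted from
+        # processor_completed as earlier volumes finish.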
- for _ in range(np.min([self.num_processes + 4, len(vol_range)])): - download_futures.append(download_executor.submit(partial_dl_executor, volume=vol_range.pop())) + for _ in range(self.num_processes * 3 // 4 + 1): + index = vol_range.pop() + download_futures.append( + download_executor.submit(partial_dl_executor, + bounding_box=volume_cache.bounding_boxes[index], + volume_index=index)) download_futures[-1].add_done_callback(dl_completed) all_work_done.wait() with multiprocessing.pool.ThreadPool(self.num_processes) as pool: - # with bound_shm[:] as boundaries: + start_time = time.perf_counter() convert_func = partial(np_convert, target_dtype=volume_cache.get_volume_dtype(), normalize=config.normalize_output, @@ -205,6 +233,7 @@ def processor_completed(future): i) for i in range(len(temp_file))]) del temp_file temp_file_path.unlink() + self.add_timing("Rewrite Temp", time.perf_counter() - start_time) except BaseException as e: traceback.print_tb(e.__traceback__, file=sys.stderr) return f"Error downloading data: {e}" @@ -218,36 +247,32 @@ def processor_completed(future): return None -def dl_worker(volume_cache: VolumeCache, volume: int, parallel_fetch: bool = False): - packet = volume_cache.create_processing_data(volume, parallel=parallel_fetch) - - # Remove the volume from the cache after the packet is created - volume_cache.remove_volume(volume) - - return packet - - def process_worker_save_parallel( config: SliceOptions, - processing_data: tuple[np.ndarray | SharedNPArray, np.ndarray, np.ndarray, int], + volume: np.ndarray | SharedNPArray, + bounding_box: np.ndarray, + slice_indices: np.ndarray, + volume_index: int, slice_rects: np.ndarray, temporary_path: str = None, shared: bool = False ) -> tuple[int, dict[str, list[float]]]: - volume, bounding_box, slice_indices, volume_index = processing_data + start_total = time.perf_counter() + if shared: volume_data = volume.array() else: volume_data = volume durations = { + "initial_load": [], "generate_grid": [], "slice_volume": [], - "save": [], - "total_process": [], + "memmap_write": [], + "total_process": [] } - start_total = time.perf_counter() + durations["initial_load"].append(time.perf_counter() - start_total) # Generate a grid for each slice and stack them along the first axis start = time.perf_counter() @@ -263,12 +288,14 @@ def process_worker_save_parallel( ) durations["slice_volume"].append(time.perf_counter() - start) + start = time.perf_counter() # Save the slices to a previously created tiff file mmap = memmap(temporary_path) mmap[slice_indices] = slices mmap.flush() del mmap + durations["memmap_write"].append(time.perf_counter() - start) durations["total_process"].append(time.perf_counter() - start_total) return volume_index, durations, np.min(slices), np.max(slices) diff --git a/python/test/helpers/test_volume_cache.py b/python/test/helpers/test_volume_cache.py index 1fb2714..ffcaafa 100644 --- a/python/test/helpers/test_volume_cache.py +++ b/python/test/helpers/test_volume_cache.py @@ -183,25 +183,6 @@ def test_request_volume_for_slice(volume_cache): assert np.all(volume_data == volume_cache.volumes[1]) -def test_create_processing_data(volume_cache): - # Patch volume_cache.download to set the volume data - with patch.object(volume_cache, "download_volume") as mock_download: - - def mock_download_func(volume_index, bounding_box, parallel): - volume_cache.volumes[volume_index] = bounding_box.to_empty_volume() - - mock_download.side_effect = mock_download_func - - # Call the method - processing_data = 
volume_cache.create_processing_data(0) - - # Check the return values - assert processing_data[0] is not None - assert processing_data[1] == volume_cache.bounding_boxes[0] - assert processing_data[2] == [0] - assert processing_data[3] == 0 - - def test_get_mip_volume_sizes(mock_cloud_volume): with patch.object(mock_cloud_volume, "mip_volume_size") as mock_mip_volume_size: mock_mip_volume_size.return_value = (100, 100, 100)
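
For reference, a rough sketch of how the pieces introduced above fit together. It assumes volume_cache is an existing VolumeCache built with use_shared=True and its bounding boxes set, and idx is an arbitrary volume index; names are illustrative only.

    import secrets

    from ouroboros.helpers.mem import SharedNPManager
    from ouroboros.helpers.volume_cache import download_volume

    authkey = secrets.token_bytes(32)

    # With an empty creation queue, __enter__ yields a 1-tuple (the manager).
    with SharedNPManager(authkey=authkey) as (shm_host,):
        # Let the cache reconnect to the same manager; remove_volume() needs it.
        volume_cache.connect_shm(shm_host.address, authkey)

        idx = 0
        # download_volume reconnects with the address/authkey and, when shared,
        # allocates the result as a termed array tracked by the manager.
        volume, bbox, dl_time, index = download_volume(
            cv=volume_cache.cv,
            bounding_box=volume_cache.bounding_boxes[idx],
            mip=volume_cache.mip,
            use_shared=volume_cache.use_shared,
            shm_address=shm_host.address,
            shm_authkey=authkey,
            volume_index=idx,
        )

        volume_cache.volumes[index] = volume
        # ... slice / process the volume here ...
        volume_cache.remove_volume(index, destroy_shared=True)  # unlink it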