from __future__ import annotations
import logging
from threading import Lock
import time
from types import SimpleNamespace
import dask
from dask.distributed import Client, LocalCluster, progress
from aperturedb.Connector import Connector

import multiprocessing as mp

from aperturedb.Stats import Stats

logger = logging.getLogger(__name__)


class DaskManager:
    """Distributes ingestion of a dask DataFrame over a local dask cluster.

    Spins up a :class:`dask.distributed.LocalCluster`, maps an ingestion
    function over the partitions of ``generator.df``, and aggregates the
    per-partition :class:`~aperturedb.Stats.Stats` results.
    """

    def __init__(self, num_workers: int = -1):
        """
        Args:
            num_workers: number of dask workers to start. The default of -1
                means "auto": use 90% of the available CPU cores
                (one worker per core).
        """
        self.__num_workers = num_workers

    def run(self, db: Connector, generator, batchsize: int, stats: bool):
        """Ingest ``generator.df`` in parallel across dask workers.

        Args:
            db: connected Connector; only its host/port/session are shipped
                to workers (a live Connector is not serializable by dask).
            generator: data generator whose ``df`` is a dask DataFrame; its
                class is re-instantiated per batch on the workers.
            batchsize: number of rows per ingestion batch within a partition.
            stats: when True, display a dask progress bar.

        Returns:
            Tuple of (per-partition results, elapsed wall-clock seconds).
        """
        def process(df, host, port, session):
            """Ingest one DataFrame partition; returns Stats for the partition."""
            metrics = Stats()
            # Dask probes partitions with a 2-row sample whose values are all
            # 'foo' to infer column names/types. Skip it, but still return an
            # empty Stats so every partition yields a uniform result type.
            if len(df) == 2 and df.iloc[0, 0] == "foo":
                return metrics
            try:
                shared_data = SimpleNamespace()
                shared_data.session = session
                shared_data.lock = Lock()
                # Use a distinct name: rebinding `db` here would shadow the
                # closure's Connector and be unbound if construction fails.
                connection = Connector(
                    host=host, port=port, shared_data=shared_data)
            except Exception as e:
                # Without a connection the partition cannot be ingested;
                # log and propagate rather than continuing into a
                # confusing UnboundLocalError.
                logger.exception(e)
                raise
            # Imported lazily to avoid a circular import at module load time.
            from aperturedb.ParallelLoader import ParallelLoader
            loader = ParallelLoader(connection)
            for start in range(0, len(df), batchsize):
                end = min(start + batchsize, len(df))
                batch = df[start:end]
                data = generator.__class__(filename="", df=batch)
                loader.ingest(generator=data, batchsize=len(batch),
                              numthreads=1, stats=False)
            # The loader accumulates timing/error counters across ingest
            # calls, so fold them into the partition metrics once at the end.
            metrics.times_arr.extend(loader.times_arr)
            metrics.error_counter += loader.error_counter
            return metrics

        # -1 is the "auto" sentinel: use 90% of the cores, 1 worker per core.
        # TODO: see if the same pool can be reused for multiple tasks.
        workers = self.__num_workers if self.__num_workers != \
            -1 else int(0.9 * mp.cpu_count())
        with LocalCluster(n_workers=workers) as cluster, Client(cluster) as client:
            dask.config.set(scheduler="distributed")
            start_time = time.time()
            # A live Connector cannot be pickled and shipped to workers,
            # so pass host/port and the session instead and let each
            # partition build its own connection.
            computation = generator.df.map_partitions(
                process,
                db.host,
                db.port,
                db.shared_data.session)
            computation = computation.persist()
            if stats:
                progress(computation)
            results = computation.compute()
            return results, time.time() - start_time