diff --git a/python/fate_arch/computing/eggroll/_table.py b/python/fate_arch/computing/eggroll/_table.py index a33af09302..5ba844d403 100644 --- a/python/fate_arch/computing/eggroll/_table.py +++ b/python/fate_arch/computing/eggroll/_table.py @@ -21,12 +21,13 @@ from fate_arch.common import log from fate_arch.common.profile import computing_profile from fate_arch.computing._type import ComputingEngine +import random +import sys LOGGER = log.getLogger() class Table(CTableABC): - def __init__(self, rp): self._rp = rp self._engine = ComputingEngine.EGGROLL @@ -49,18 +50,30 @@ def save(self, address, partitions, schema: dict, **kwargs): options = kwargs.get("options", {}) from fate_arch.common.address import EggRollAddress from fate_arch.storage import EggRollStoreType + if isinstance(address, EggRollAddress): - options["store_type"] = kwargs.get("store_type", EggRollStoreType.ROLLPAIR_LMDB) - self._rp.save_as(name=address.name, namespace=address.namespace, partition=partitions, options=options) + options["store_type"] = kwargs.get( + "store_type", EggRollStoreType.ROLLPAIR_LMDB + ) + self._rp.save_as( + name=address.name, + namespace=address.namespace, + partition=partitions, + options=options, + ) schema.update(self.schema) return from fate_arch.common.address import PathAddress + if isinstance(address, PathAddress): from fate_arch.computing.non_distributed import LocalData + return LocalData(address.path) - raise NotImplementedError(f"address type {type(address)} not supported with eggroll backend") + raise NotImplementedError( + f"address type {type(address)} not supported with eggroll backend" + ) @computing_profile def collect(self, **kwargs) -> list: @@ -95,14 +108,22 @@ def applyPartitions(self, func): return Table(self._rp.collapse_partitions(func)) @computing_profile - def mapPartitions(self, func, use_previous_behavior=True, preserves_partitioning=False, **kwargs): + def mapPartitions( + self, func, use_previous_behavior=True, preserves_partitioning=False, **kwargs + ): if use_previous_behavior is True: - LOGGER.warning(f"please use `applyPartitions` instead of `mapPartitions` " - f"if the previous behavior was expected. " - f"The previous behavior will not work in future") + LOGGER.warning( + f"please use `applyPartitions` instead of `mapPartitions` " + f"if the previous behavior was expected. " + f"The previous behavior will not work in future" + ) return self.applyPartitions(func) - return Table(self._rp.map_partitions(func, options={"shuffle": not preserves_partitioning})) + return Table( + self._rp.map_partitions( + func, options={"shuffle": not preserves_partitioning} + ) + ) @computing_profile def mapReducePartitions(self, mapper, reducer, **kwargs): @@ -110,14 +131,18 @@ def mapReducePartitions(self, mapper, reducer, **kwargs): @computing_profile def mapPartitionsWithIndex(self, func, preserves_partitioning=False, **kwargs): - return Table(self._rp.map_partitions_with_index(func, options={"shuffle": not preserves_partitioning})) + return Table( + self._rp.map_partitions_with_index( + func, options={"shuffle": not preserves_partitioning} + ) + ) @computing_profile def reduce(self, func, **kwargs): return self._rp.reduce(func) @computing_profile - def join(self, other: 'Table', func, **kwargs): + def join(self, other: "Table", func, **kwargs): return Table(self._rp.join(other._rp, func=func)) @computing_profile @@ -125,35 +150,25 @@ def glom(self, **kwargs): return Table(self._rp.glom()) @computing_profile - def sample(self, *, fraction: typing.Optional[float] = None, num: typing.Optional[int] = None, seed=None): + def sample( + self, + *, + fraction: typing.Optional[float] = None, + num: typing.Optional[int] = None, + seed=None, + ): if fraction is not None: return Table(self._rp.sample(fraction=fraction, seed=seed)) if num is not None: - total = self._rp.count() - if num > total: - raise ValueError(f"not enough data to sample, own {total} but required {num}") - - frac = num / float(total) - while True: - sampled_table = self._rp.sample(fraction=frac, seed=seed) - sampled_count = sampled_table.count() - if sampled_count < num: - frac *= 1.1 - else: - break + return _exactly_sample(self, num, seed) - if sampled_count > num: - drops = sampled_table.take(sampled_count - num) - for k, v in drops: - sampled_table.delete(k) - - return Table(sampled_table) - - raise ValueError(f"exactly one of `fraction` or `num` required, fraction={fraction}, num={num}") + raise ValueError( + f"exactly one of `fraction` or `num` required, fraction={fraction}, num={num}" + ) @computing_profile - def subtractByKey(self, other: 'Table', **kwargs): + def subtractByKey(self, other: "Table", **kwargs): return Table(self._rp.subtract_by_key(other._rp)) @computing_profile @@ -161,7 +176,7 @@ def filter(self, func, **kwargs): return Table(self._rp.filter(func)) @computing_profile - def union(self, other: 'Table', func=lambda v1, v2: v1, **kwargs): + def union(self, other: "Table", func=lambda v1, v2: v1, **kwargs): return Table(self._rp.union(other._rp, func=func)) @computing_profile @@ -169,3 +184,60 @@ def flatMap(self, func, **kwargs): flat_map = self._rp.flat_map(func) shuffled = flat_map.map(lambda k, v: (k, v)) # trigger shuffle return Table(shuffled) + + +def _exactly_sample(table: Table, num, seed): + from scipy.stats import hypergeom + + split_size = list( + table.mapPartitionsWithIndex(lambda s, it: [(s, sum(1 for _ in it))]).collect() + ) + total = sum(v for _, v in split_size) + + if num > total: + raise ValueError(f"not enough data to sample, own {total} but required {num}") + # random the size of each split + sampled_size = {} + for split, size in split_size: + if size <= 0: + sampled_size[split] = 0 + else: + sampled_size[split] = hypergeom.rvs(M=total, n=size, N=num) + total = total - size + num = num - sampled_size[split] + + return table.mapPartitionsWithIndex( + func=_ReservoirSample(split_sample_size=sampled_size, seed=seed).func, + shuffle=False, + ) + + +class _ReservoirSample: + def __init__(self, split_sample_size, seed): + self._split_sample_size = split_sample_size + self._counter = 0 + self._sample = [] + self._seed = seed if seed is not None else random.randint(0, sys.maxsize) + self._random = None + + def initRandomGenerator(self, split): + self._random = random.Random(self._seed ^ split) + + # mixing because the initial seeds are close to each other + for _ in range(10): + self._random.randint(0, 1) + + def func(self, split, iterator): + self.initRandomGenerator(split) + size = self._split_sample_size[split] + for obj in iterator: + self._counter += 1 + if len(self._sample) < size: + self._sample.append(obj) + continue + + randint = self._random.randint(1, self._counter) + if randint <= size: + self._sample[randint - 1] = obj + + return self._sample diff --git a/python/fate_arch/computing/standalone/_table.py b/python/fate_arch/computing/standalone/_table.py index 77f5967be1..3f400c81ea 100644 --- a/python/fate_arch/computing/standalone/_table.py +++ b/python/fate_arch/computing/standalone/_table.py @@ -15,6 +15,8 @@ # import itertools +import random +import sys import typing from fate_arch.abc import CTableABC @@ -156,27 +158,7 @@ def sample( return Table(self._table.sample(fraction=fraction, seed=seed)) if num is not None: - total = self._table.count() - if num > total: - raise ValueError( - f"not enough data to sample, own {total} but required {num}" - ) - - frac = num / float(total) - while True: - sampled_table = self._table.sample(fraction=frac, seed=seed) - sampled_count = sampled_table.count() - if sampled_count < num: - frac += 0.1 - else: - break - - if sampled_count > num: - drops = sampled_table.take(sampled_count - num) - for k, v in drops: - sampled_table.delete(k) - - return Table(sampled_table) + return _exactly_sample(self, num, seed) raise ValueError( f"exactly one of `fraction` or `num` required, fraction={fraction}, num={num}" @@ -197,3 +179,60 @@ def subtractByKey(self, other: "Table"): @computing_profile def union(self, other: "Table", func=lambda v1, v2: v1): return Table(self._table.union(other._table, func)) + + +def _exactly_sample(table: Table, num, seed): + from scipy.stats import hypergeom + + split_size = list( + table.mapPartitionsWithIndex(lambda s, it: [(s, sum(1 for _ in it))]).collect() + ) + total = sum(v for _, v in split_size) + + if num > total: + raise ValueError(f"not enough data to sample, own {total} but required {num}") + # random the size of each split + sampled_size = {} + for split, size in split_size: + if size <= 0: + sampled_size[split] = 0 + else: + sampled_size[split] = hypergeom.rvs(M=total, n=size, N=num) + total = total - size + num = num - sampled_size[split] + + return table.mapPartitionsWithIndex( + func=_ReservoirSample(split_sample_size=sampled_size, seed=seed).func, + shuffle=False, + ) + + +class _ReservoirSample: + def __init__(self, split_sample_size, seed): + self._split_sample_size = split_sample_size + self._counter = 0 + self._sample = [] + self._seed = seed if seed is not None else random.randint(0, sys.maxsize) + self._random = None + + def initRandomGenerator(self, split): + self._random = random.Random(self._seed ^ split) + + # mixing because the initial seeds are close to each other + for _ in range(10): + self._random.randint(0, 1) + + def func(self, split, iterator): + self.initRandomGenerator(split) + size = self._split_sample_size[split] + for obj in iterator: + self._counter += 1 + if len(self._sample) < size: + self._sample.append(obj) + continue + + randint = self._random.randint(1, self._counter) + if randint <= size: + self._sample[randint - 1] = obj + + return self._sample