
Commit 55da27b

Author: John Andersen

df: memory: Fix redundancy checker
Signed-off-by: John Andersen <[email protected]>
Parent: 652a974

File tree: 3 files changed, +98 -30 lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
@@ -12,6 +12,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ### Fixed
 - DataFlows with multiple possibilities for a source for an input, now correctly
   look through all possible sources instead of just the first one.
+- DataFlow MemoryRedundancyCheckerContext was using all inputs in an input set
+  and all their ancestors to check redundancy (a hold over from pre uid days).
+  It now correctly only uses the inputs in the parameter set. This fixes a major
+  performance issue.
 
 ## [0.3.0] - 2019-10-26
 ### Added

dffml/df/memory.py

Lines changed: 91 additions & 30 deletions
@@ -5,9 +5,10 @@
 import inspect
 import itertools
 import traceback
+import concurrent.futures
 from datetime import datetime
 from itertools import product, chain
-from contextlib import asynccontextmanager, AsyncExitStack
+from contextlib import asynccontextmanager, AsyncExitStack, ExitStack
 from typing import (
     AsyncIterator,
     Dict,
@@ -58,7 +59,7 @@
 from ..util.cli.arg import Arg
 from ..util.cli.cmd import CMD
 from ..util.data import ignore_args, traverse_get
-from ..util.asynchelper import context_stacker, aenter_stack
+from ..util.asynchelper import context_stacker, aenter_stack, concurrently
 
 from .log import LOGGER
 
@@ -488,14 +489,19 @@ async def gather_inputs(
             if not gather[input_name]:
                 return
         # Generate all possible permutations of applicable inputs
-        for permutation in product(*list(gather.values())):
-            # Create the parameter set
-            parameter_set = MemoryParameterSet(
-                MemoryParameterSetConfig(ctx=ctx, parameters=permutation)
+        # Create the parameter set for each
+        products = list(
+            map(
+                lambda permutation: MemoryParameterSet(
+                    MemoryParameterSetConfig(ctx=ctx, parameters=permutation)
+                ),
+                product(*list(gather.values())),
             )
-            # Check if this permutation has been executed before
-            if not await rctx.exists(operation, parameter_set):
-                # If not then return the permutation
+        )
+        # Check if each permutation has been executed before
+        async for parameter_set, exists in rctx.exists(operation, *products):
+            # If not then yield the permutation
+            if not exists:
                 yield parameter_set
 
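The gather_inputs change above converts rctx.exists from a single-await predicate into an async generator that answers for a whole batch of parameter sets in one call. A minimal runnable sketch of that calling pattern, using toy stand-ins (ToyRedundancyChecker and gather_new are illustrative, not dffml APIs):

import asyncio
from itertools import product


class ToyRedundancyChecker:
    def __init__(self):
        self.seen = set()

    async def exists(self, *candidates):
        # Async generator: yield (candidate, already-seen?) pairs so one
        # call can answer for a whole batch of permutations.
        for candidate in candidates:
            yield candidate, candidate in self.seen


async def gather_new(checker, *value_lists):
    # Build every permutation up front, then check them all in one call.
    candidates = list(product(*value_lists))
    async for candidate, exists in checker.exists(*candidates):
        if not exists:
            yield candidate


async def main():
    checker = ToyRedundancyChecker()
    checker.seen.add((1, "b"))
    async for candidate in gather_new(checker, [1, 2], ["a", "b"]):
        print(candidate)  # (1, 'a'), (2, 'a'), (2, 'b')


asyncio.run(main())

Building all permutations up front is what lets the redundancy checker below choose between a serial path and a thread-pooled path based on batch size.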

@@ -592,35 +598,63 @@ async def __aenter__(self) -> "MemoryRedundancyCheckerContext":
     async def __aexit__(self, exc_type, exc_value, traceback):
         await self.__stack.aclose()
 
+    @staticmethod
+    def _unique(instance_name: str, handle: str, *uids: str) -> str:
+        """
+        SHA384 hash of the parameter set context handle as a string, the
+        operation.instance_name, and the sorted list of input uuids.
+        """
+        uid_list = [instance_name, handle] + sorted(uids)
+        return hashlib.sha384("".join(uid_list).encode("utf-8")).hexdigest()
+
     async def unique(
         self, operation: Operation, parameter_set: BaseParameterSet
     ) -> str:
         """
         SHA384 hash of the parameter set context handle as a string, the
         operation.instance_name, and the sorted list of input uuids.
         """
-        uid_list = sorted(
-            map(
-                lambda x: x.uid,
-                [item async for item in parameter_set.inputs()],
-            )
+        uid_list = [
+            operation.instance_name,
+            (await parameter_set.ctx.handle()).as_string(),
+        ] + sorted(
+            [item.origin.uid async for item in parameter_set.parameters()]
         )
-        uid_list.insert(0, (await parameter_set.ctx.handle()).as_string())
-        uid_list.insert(0, operation.instance_name)
-        return hashlib.sha384(", ".join(uid_list).encode("utf-8")).hexdigest()
+        return hashlib.sha384("".join(uid_list).encode("utf-8")).hexdigest()
+
+    async def _exists(self, coro) -> bool:
+        return bool(await self.kvctx.get(await coro) == "\x01")
 
     async def exists(
-        self, operation: Operation, parameter_set: BaseParameterSet
+        self, operation: Operation, *parameter_sets: BaseParameterSet
     ) -> bool:
-        # self.logger.debug('checking parameter_set: %s', list(map(
-        #     lambda p: p.value,
-        #     [p async for p in parameter_set.parameters()])))
-        if (
-            await self.kvctx.get(await self.unique(operation, parameter_set))
-            != "\x01"
-        ):
-            return False
-        return True
+        # TODO(p4) Run tests to choose an optimal threaded vs non-threaded value
+        if len(parameter_sets) < 4:
+            for parameter_set in parameter_sets:
+                yield parameter_set, await self._exists(
+                    self.unique(operation, parameter_set)
+                )
+        else:
+            async for parameter_set, exists in concurrently(
+                {
+                    asyncio.create_task(
+                        self._exists(
+                            self.parent.loop.run_in_executor(
+                                self.parent.pool,
+                                self._unique,
+                                operation.instance_name,
+                                (await parameter_set.ctx.handle()).as_string(),
+                                *[
+                                    item.origin.uid
+                                    async for item in parameter_set.parameters()
+                                ],
+                            )
+                        )
+                    ): parameter_set
+                    for parameter_set in parameter_sets
+                }
+            ):
+                yield parameter_set, exists
 
     async def add(self, operation: Operation, parameter_set: BaseParameterSet):
         # self.logger.debug('adding parameter_set: %s', list(map(
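
The new exists() is an async generator: small batches are checked serially, while four or more parameter sets have their SHA384 hashing (_unique) offloaded to the parent's thread pool via run_in_executor, with results drained through concurrently() from ..util.asynchelper. That helper's implementation is not shown in this diff, so the sketch below is only an assumed shape: a dict mapping asyncio tasks to keys, yielding (key, result) pairs as tasks complete.

import asyncio
from typing import AsyncIterator, Dict, Tuple, TypeVar

K = TypeVar("K")


async def concurrently_sketch(
    work: Dict[asyncio.Task, K]
) -> AsyncIterator[Tuple[K, object]]:
    # Wait on all tasks at once and yield each (key, result) as soon as
    # that task completes, so one slow lookup never blocks the rest.
    pending = set(work)
    while pending:
        done, pending = await asyncio.wait(
            pending, return_when=asyncio.FIRST_COMPLETED
        )
        for task in done:
            yield work[task], task.result()

Under that assumption, the dict comprehension in exists() maps each hashing task back to its parameter set, so callers receive answers as soon as each lookup finishes rather than in submission order.
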
@@ -639,15 +673,28 @@ class MemoryRedundancyChecker(BaseRedundancyChecker, BaseMemoryDataFlowObject):
 
     CONTEXT = MemoryRedundancyCheckerContext
 
+    def __init__(self, config):
+        super().__init__(config)
+        self.loop = None
+        self.pool = None
+        self.__pool = None
+
     async def __aenter__(self) -> "MemoryRedundancyCheckerContext":
         self.__stack = AsyncExitStack()
+        self.__exit_stack = ExitStack()
+        self.__exit_stack.__enter__()
         await self.__stack.__aenter__()
         self.key_value_store = await self.__stack.enter_async_context(
             self.config.key_value_store
         )
+        self.loop = asyncio.get_event_loop()
+        self.pool = self.__exit_stack.enter_context(
+            concurrent.futures.ThreadPoolExecutor()
+        )
         return self
 
     async def __aexit__(self, exc_type, exc_value, traceback):
+        self.__exit_stack.__exit__(exc_type, exc_value, traceback)
         await self.__stack.__aexit__(exc_type, exc_value, traceback)
 
     @classmethod
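
The checker now manages two stacks: async resources (the key-value store) on an AsyncExitStack, and the synchronous ThreadPoolExecutor on a plain ExitStack, with __aexit__ unwinding the sync stack first. A self-contained sketch of the pattern, assuming only that the ExitStack exists to own the pool (PooledResource is illustrative, not a dffml class):

import asyncio
import concurrent.futures
from contextlib import AsyncExitStack, ExitStack


class PooledResource:
    async def __aenter__(self):
        # Async resources go on the AsyncExitStack, sync ones on the
        # ExitStack; both are entered here and unwound in __aexit__.
        self._astack = AsyncExitStack()
        self._stack = ExitStack()
        self._stack.__enter__()
        await self._astack.__aenter__()
        self.loop = asyncio.get_event_loop()
        # ExitStack guarantees pool.shutdown() on exit.
        self.pool = self._stack.enter_context(
            concurrent.futures.ThreadPoolExecutor()
        )
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        self._stack.__exit__(exc_type, exc_value, traceback)
        await self._astack.__aexit__(exc_type, exc_value, traceback)

    async def offload(self, func, *args):
        # run_in_executor moves CPU-bound work off the event loop thread.
        return await self.loop.run_in_executor(self.pool, func, *args)

Exiting the ExitStack calls ThreadPoolExecutor.shutdown(), so worker threads are reclaimed before the async resources are released, matching the ordering in the diff.
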
@@ -831,7 +878,13 @@ async def run(
             operation.stage.value.upper(),
             operation.instance_name,
         )
-        self.logger.debug("Inputs: %s", inputs)
+        str_inputs = str(inputs)
+        self.logger.debug(
+            "Inputs: %s",
+            str_inputs
+            if len(str_inputs) < 512
+            else (str_inputs[:512] + "..."),
+        )
         self.logger.debug(
             "Conditions: %s",
             dict(
@@ -845,7 +898,13 @@ async def run(
             ),
         )
         outputs = await opctx.run(inputs)
-        self.logger.debug("Output: %s", outputs)
+        str_outputs = str(outputs)
+        self.logger.debug(
+            "Outputs: %s",
+            str_outputs
+            if len(str_outputs) < 512
+            else (str_outputs[:512] + "..."),
+        )
         self.logger.debug("---")
         return outputs
 
@@ -882,7 +941,9 @@ async def run_dispatch(
             expand = operation.expand
         else:
             expand = []
-        parents = [item async for item in parameter_set.inputs()]
+        parents = [
+            item.origin async for item in parameter_set.parameters()
+        ]
         for key, output in outputs.items():
             if not key in expand:
                 output = [output]
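
This run_dispatch change pairs with the CHANGELOG entry: a Parameter records the Input it was gathered from as .origin, so collecting item.origin yields exactly the inputs in the parameter set, while the old parameter_set.inputs() also pulled in each input's ancestors. A sketch of the assumed relationship (the Toy* classes are illustrations, not the dffml types):

from dataclasses import dataclass, field
from typing import List


@dataclass
class ToyInput:
    value: object
    uid: str
    parents: List["ToyInput"] = field(default_factory=list)

    def get_parents(self):
        # Full ancestry: the extra inputs the old check was hashing over.
        for parent in self.parents:
            yield parent
            yield from parent.get_parents()


@dataclass
class ToyParameter:
    key: str
    value: object
    origin: ToyInput  # the Input this parameter was gathered from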

dffml/df/types.py

Lines changed: 3 additions & 0 deletions
@@ -417,6 +417,9 @@ def __post_init__(self):
         # Determine the dataflow if not given
         if self.flow is None:
             self.flow = self.auto_flow()
+        self.update_by_origin()
+
+    def update_by_origin(self):
         # Create by_origin which maps operation instance names to the sources
         self.by_origin = {}
         for instance_name, input_flow in self.flow.items():
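
Extracting the by_origin construction into update_by_origin() leaves __post_init__ behavior unchanged while making the index rebuildable on demand. A plausible use, inferred from the refactor rather than stated in the commit:

# Hypothetical: after mutating a DataFlow's routing post-construction,
# the origin index can be rebuilt without re-creating the object.
# dataflow.flow["my_operation"] = new_input_flow  # names are illustrative
# dataflow.update_by_origin()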
