55import inspect
66import itertools
77import traceback
8+ import concurrent .futures
89from datetime import datetime
910from itertools import product , chain
10- from contextlib import asynccontextmanager , AsyncExitStack
11+ from contextlib import asynccontextmanager , AsyncExitStack , ExitStack
1112from typing import (
1213 AsyncIterator ,
1314 Dict ,
5859from ..util .cli .arg import Arg
5960from ..util .cli .cmd import CMD
6061from ..util .data import ignore_args , traverse_get
61- from ..util .asynchelper import context_stacker , aenter_stack
62+ from ..util .asynchelper import context_stacker , aenter_stack , concurrently
6263
6364from .log import LOGGER
6465
@@ -488,14 +489,19 @@ async def gather_inputs(
488489 if not gather [input_name ]:
489490 return
490491 # Generate all possible permutations of applicable inputs
491- for permutation in product (* list (gather .values ())):
492- # Create the parameter set
493- parameter_set = MemoryParameterSet (
494- MemoryParameterSetConfig (ctx = ctx , parameters = permutation )
492+ # Create the parameter set for each
493+ products = list (
494+ map (
495+ lambda permutation : MemoryParameterSet (
496+ MemoryParameterSetConfig (ctx = ctx , parameters = permutation )
497+ ),
498+ product (* list (gather .values ())),
495499 )
496- # Check if this permutation has been executed before
497- if not await rctx .exists (operation , parameter_set ):
498- # If not then return the permutation
500+ )
501+ # Check if each permutation has been executed before
502+ async for parameter_set , exists in rctx .exists (operation , * products ):
503+ # If not then yield the permutation
504+ if not exists :
499505 yield parameter_set
500506
501507
@@ -592,35 +598,63 @@ async def __aenter__(self) -> "MemoryRedundancyCheckerContext":
592598 async def __aexit__ (self , exc_type , exc_value , traceback ):
593599 await self .__stack .aclose ()
594600
601+ @staticmethod
602+ def _unique (instance_name : str , handle : str , * uids : str ) -> str :
603+ """
604+ SHA384 hash of the parameter set context handle as a string, the
605+ operation.instance_name, and the sorted list of input uuids.
606+ """
607+ uid_list = [instance_name , handle ] + sorted (uids )
608+ return hashlib .sha384 ("" .join (uid_list ).encode ("utf-8" )).hexdigest ()
609+
595610 async def unique (
596611 self , operation : Operation , parameter_set : BaseParameterSet
597612 ) -> str :
598613 """
599614 SHA384 hash of the parameter set context handle as a string, the
600615 operation.instance_name, and the sorted list of input uuids.
601616 """
602- uid_list = sorted (
603- map (
604- lambda x : x . uid ,
605- [ item async for item in parameter_set . inputs ()],
606- )
617+ uid_list = [
618+ operation . instance_name ,
619+ ( await parameter_set . ctx . handle ()). as_string () ,
620+ ] + sorted (
621+ [ item . origin . uid async for item in parameter_set . parameters ()]
607622 )
608- uid_list .insert (0 , (await parameter_set .ctx .handle ()).as_string ())
609- uid_list .insert (0 , operation .instance_name )
610- return hashlib .sha384 (", " .join (uid_list ).encode ("utf-8" )).hexdigest ()
623+ return hashlib .sha384 ("" .join (uid_list ).encode ("utf-8" )).hexdigest ()
624+
625+ async def _exists (self , coro ) -> bool :
626+ return bool (await self .kvctx .get (await coro ) == "\x01 " )
611627
612628 async def exists (
613- self , operation : Operation , parameter_set : BaseParameterSet
629+ self , operation : Operation , * parameter_sets : BaseParameterSet
614630 ) -> bool :
615- # self.logger.debug('checking parameter_set: %s', list(map(
616- # lambda p: p.value,
617- # [p async for p in parameter_set.parameters()])))
618- if (
619- await self .kvctx .get (await self .unique (operation , parameter_set ))
620- != "\x01 "
621- ):
622- return False
623- return True
631+ # TODO(p4) Run tests to choose an optimal threaded vs non-threaded value
632+ if len (parameter_sets ) < 4 :
633+ for parameter_set in parameter_sets :
634+ yield parameter_set , await self ._exists (
635+ self .unique (operation , parameter_set )
636+ )
637+ else :
638+ async for parameter_set , exists in concurrently (
639+ {
640+ asyncio .create_task (
641+ self ._exists (
642+ self .parent .loop .run_in_executor (
643+ self .parent .pool ,
644+ self ._unique ,
645+ operation .instance_name ,
646+ (await parameter_set .ctx .handle ()).as_string (),
647+ * [
648+ item .origin .uid
649+ async for item in parameter_set .parameters ()
650+ ],
651+ )
652+ )
653+ ): parameter_set
654+ for parameter_set in parameter_sets
655+ }
656+ ):
657+ yield parameter_set , exists
624658
625659 async def add (self , operation : Operation , parameter_set : BaseParameterSet ):
626660 # self.logger.debug('adding parameter_set: %s', list(map(
@@ -639,15 +673,28 @@ class MemoryRedundancyChecker(BaseRedundancyChecker, BaseMemoryDataFlowObject):
639673
640674 CONTEXT = MemoryRedundancyCheckerContext
641675
676+ def __init__ (self , config ):
677+ super ().__init__ (config )
678+ self .loop = None
679+ self .pool = None
680+ self .__pool = None
681+
642682 async def __aenter__ (self ) -> "MemoryRedundancyCheckerContext" :
643683 self .__stack = AsyncExitStack ()
684+ self .__exit_stack = ExitStack ()
685+ self .__exit_stack .__enter__ ()
644686 await self .__stack .__aenter__ ()
645687 self .key_value_store = await self .__stack .enter_async_context (
646688 self .config .key_value_store
647689 )
690+ self .loop = asyncio .get_event_loop ()
691+ self .pool = self .__exit_stack .enter_context (
692+ concurrent .futures .ThreadPoolExecutor ()
693+ )
648694 return self
649695
650696 async def __aexit__ (self , exc_type , exc_value , traceback ):
697+ self .__exit_stack .__exit__ (exc_type , exc_value , traceback )
651698 await self .__stack .__aexit__ (exc_type , exc_value , traceback )
652699
653700 @classmethod
@@ -831,7 +878,13 @@ async def run(
831878 operation .stage .value .upper (),
832879 operation .instance_name ,
833880 )
834- self .logger .debug ("Inputs: %s" , inputs )
881+ str_inputs = str (inputs )
882+ self .logger .debug (
883+ "Inputs: %s" ,
884+ str_inputs
885+ if len (str_inputs ) < 512
886+ else (str_inputs [:512 ] + "..." ),
887+ )
835888 self .logger .debug (
836889 "Conditions: %s" ,
837890 dict (
@@ -845,7 +898,13 @@ async def run(
845898 ),
846899 )
847900 outputs = await opctx .run (inputs )
848- self .logger .debug ("Output: %s" , outputs )
901+ str_outputs = str (outputs )
902+ self .logger .debug (
903+ "Outputs: %s" ,
904+ str_outputs
905+ if len (str_outputs ) < 512
906+ else (str_outputs [:512 ] + "..." ),
907+ )
849908 self .logger .debug ("---" )
850909 return outputs
851910
@@ -882,7 +941,9 @@ async def run_dispatch(
882941 expand = operation .expand
883942 else :
884943 expand = []
885- parents = [item async for item in parameter_set .inputs ()]
944+ parents = [
945+ item .origin async for item in parameter_set .parameters ()
946+ ]
886947 for key , output in outputs .items ():
887948 if not key in expand :
888949 output = [output ]
0 commit comments