@@ -71,7 +71,7 @@ def make_subcomm(self, new_targets):
7171 pass
7272
7373 @abstractmethod
74- def apply (self , func , args = None , kwargs = None , targets = None ):
74+ def apply (self , func , args = None , kwargs = None , targets = None , autoproxyize = False ):
7575 pass
7676
7777 @abstractmethod
@@ -816,6 +816,10 @@ def func_wrapper(func, apply_nonce, context_key, args, kwargs, autoproxyize):
816816 # default arguments
817817 args = () if args is None else args
818818 kwargs = {} if kwargs is None else kwargs
819+
820+ args = tuple (a .key if isinstance (a , DistArray ) else a for a in args )
821+ kwargs = {k : (v .key if isinstance (v , DistArray ) else v ) for k , v in kwargs .items ()}
822+
819823 apply_nonce = nonce ()
820824 wrapped_args = (func , apply_nonce , self .context_key , args , kwargs , autoproxyize )
821825
@@ -972,6 +976,10 @@ def apply(self, func, args=None, kwargs=None, targets=None, autoproxyize=False):
972976 # default arguments
973977 args = () if args is None else args
974978 kwargs = {} if kwargs is None else kwargs
979+
980+ args = tuple (a .key if isinstance (a , DistArray ) else a for a in args )
981+ kwargs = {k : (v .key if isinstance (v , DistArray ) else v ) for k , v in kwargs .items ()}
982+
975983 targets = self .targets if targets is None else targets
976984
977985 apply_nonce = nonce ()
0 commit comments