@@ -82,6 +82,7 @@ class _ExtensibleCallable:
8282 func : Callable
8383 batch_func : Optional [Callable ]
8484 is_async : bool
85+ has_single_func : bool
8586
8687 def __call__ (self , * args , ** kwargs ):
8788 if self .is_async :
@@ -91,20 +92,26 @@ def __call__(self, *args, **kwargs):
9192
9293 async def _async_call (self , * args , ** kwargs ):
9394 try :
94- return await self .func (* args , ** kwargs )
95+ if self .has_single_func :
96+ return await self .func (* args , ** kwargs )
9597 except NotImplementedError :
96- if self .batch_func :
97- ret = await self .batch_func ([args ], [kwargs ])
98- return None if ret is None else ret [0 ]
99- raise
98+ self .has_single_func = False
99+
100+ if self .batch_func is not None :
101+ ret = await self .batch_func ([args ], [kwargs ])
102+ return None if ret is None else ret [0 ]
103+ raise NotImplementedError
100104
101105 def _sync_call (self , * args , ** kwargs ):
102106 try :
103- return self .func (* args , ** kwargs )
107+ if self .has_single_func :
108+ return self .func (* args , ** kwargs )
104109 except NotImplementedError :
105- if self .batch_func :
106- return self .batch_func ([args ], [kwargs ])[0 ]
107- raise
110+ self .has_single_func = False
111+
112+ if self .batch_func is not None :
113+ return self .batch_func ([args ], [kwargs ])[0 ]
114+ raise NotImplementedError
108115
109116
110117class _ExtensibleWrapper (_ExtensibleCallable ):
@@ -119,6 +126,7 @@ def __init__(
119126 self .batch_func = batch_func
120127 self .bind_func = bind_func
121128 self .is_async = is_async
129+ self .has_single_func = True
122130
123131 @staticmethod
124132 def delay (* args , ** kwargs ):
@@ -138,7 +146,7 @@ async def _async_batch(self, *delays):
138146 # will be more efficient
139147 if len (delays ) == 1 :
140148 d = delays [0 ]
141- return [await self .func (* d .args , ** d .kwargs )]
149+ return [await self ._async_call (* d .args , ** d .kwargs )]
142150 elif self .batch_func :
143151 args_list , kwargs_list = self ._gen_args_kwargs_list (delays )
144152 return await self .batch_func (args_list , kwargs_list )
@@ -184,6 +192,7 @@ def __init__(self, func: Callable):
184192 self .batch_func = None
185193 self .bind_func = build_args_binder (func , remove_self = True )
186194 self .is_async = asyncio .iscoroutinefunction (self .func )
195+ self .has_single_func = True
187196
188197 def batch (self , func : Callable ):
189198 self .batch_func = func
0 commit comments