Skip to content

Commit 91400a3

Browse files
committed
attempt no 1 refactoring background callbacks for async functions
1 parent 96df44e commit 91400a3

File tree

3 files changed

+270
-82
lines changed

3 files changed

+270
-82
lines changed

dash/_callback.py

Lines changed: 120 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -379,21 +379,10 @@ def initialize_context(args, kwargs, inputs_state_indices):
379379
False,
380380
)
381381

382-
def handle_long_callback(
383-
kwargs,
384-
long,
385-
long_key,
386-
callback_ctx,
387-
response,
388-
error_handler,
389-
func,
390-
func_args,
391-
func_kwargs,
392-
has_update=False,
393-
):
382+
def get_callback_manager(kwargs):
394383
"""Set up the long callback and manage jobs."""
395384
callback_manager = long.get(
396-
"manager", kwargs.pop("long_callback_manager", None)
385+
"manager", kwargs.get("long_callback_manager", None)
397386
)
398387
if not callback_manager:
399388
raise MissingLongCallbackManagerError(
@@ -405,63 +394,120 @@ def handle_long_callback(
405394
" and store results on redis.\n"
406395
)
407396

408-
progress_outputs = long.get("progress")
409-
cache_key = flask.request.args.get("cacheKey")
410-
job_id = flask.request.args.get("job")
411397
old_job = flask.request.args.getlist("oldJob")
412398

413-
current_key = callback_manager.build_cache_key(
399+
if old_job:
400+
for job in old_job:
401+
callback_manager.terminate_job(job)
402+
403+
return callback_manager
404+
405+
def setup_long_callback(
406+
kwargs,
407+
long,
408+
long_key,
409+
func,
410+
func_args,
411+
func_kwargs,
412+
):
413+
"""Set up the long callback and manage jobs."""
414+
callback_manager = get_callback_manager(kwargs)
415+
416+
progress_outputs = long.get("progress")
417+
418+
cache_key = callback_manager.build_cache_key(
414419
func,
415420
# Inputs provided as dict is kwargs.
416421
func_args if func_args else func_kwargs,
417422
long.get("cache_args_to_ignore", []),
418423
)
419424

420-
if old_job:
421-
for job in old_job:
422-
callback_manager.terminate_job(job)
425+
job_fn = callback_manager.func_registry.get(long_key)
423426

424-
if not cache_key:
425-
cache_key = current_key
427+
ctx_value = AttributeDict(**context_value.get())
428+
ctx_value.ignore_register_page = True
429+
ctx_value.pop("background_callback_manager")
430+
ctx_value.pop("dash_response")
426431

427-
job_fn = callback_manager.func_registry.get(long_key)
432+
job = callback_manager.call_job_fn(
433+
cache_key,
434+
job_fn,
435+
func_args if func_args else func_kwargs,
436+
ctx_value,
437+
)
428438

429-
ctx_value = AttributeDict(**context_value.get())
430-
ctx_value.ignore_register_page = True
431-
ctx_value.pop("background_callback_manager")
432-
ctx_value.pop("dash_response")
439+
data = {
440+
"cacheKey": cache_key,
441+
"job": job,
442+
}
433443

434-
job = callback_manager.call_job_fn(
435-
cache_key,
436-
job_fn,
437-
func_args if func_args else func_kwargs,
438-
ctx_value,
439-
)
444+
cancel = long.get("cancel")
445+
if cancel:
446+
data["cancel"] = cancel
440447

441-
data = {
442-
"cacheKey": cache_key,
443-
"job": job,
448+
progress_default = long.get("progressDefault")
449+
if progress_default:
450+
data["progressDefault"] = {
451+
str(o): x for o, x in zip(progress_outputs, progress_default)
444452
}
453+
return to_json(data)
445454

446-
cancel = long.get("cancel")
447-
if cancel:
448-
data["cancel"] = cancel
455+
def progress_long_callback(response, callback_manager):
456+
progress_outputs = long.get("progress")
457+
cache_key = flask.request.args.get("cacheKey")
449458

450-
progress_default = long.get("progressDefault")
451-
if progress_default:
452-
data["progressDefault"] = {
453-
str(o): x for o, x in zip(progress_outputs, progress_default)
454-
}
455-
return to_json(data), True, has_update
456459
if progress_outputs:
457460
# Get the progress before the result as it would be erased after the results.
458461
progress = callback_manager.get_progress(cache_key)
459462
if progress:
460-
response["progress"] = {
461-
str(x): progress[i] for i, x in enumerate(progress_outputs)
462-
}
463+
response.update(
464+
{
465+
"progress": {
466+
str(x): progress[i] for i, x in enumerate(progress_outputs)
467+
}
468+
}
469+
)
470+
471+
def update_long_callback(error_handler, callback_ctx, response, kwargs):
472+
"""Set up the long callback and manage jobs."""
473+
callback_manager = get_callback_manager(kwargs)
474+
475+
cache_key = flask.request.args.get("cacheKey")
476+
job_id = flask.request.args.get("job")
463477

464478
output_value = callback_manager.get_result(cache_key, job_id)
479+
480+
progress_long_callback(response, callback_manager)
481+
482+
return handle_rest_long_callback(
483+
output_value, callback_manager, response, error_handler, callback_ctx
484+
)
485+
486+
async def async_update_long_callback(error_handler, callback_ctx, response, kwargs):
487+
"""Set up the long callback and manage jobs."""
488+
callback_manager = get_callback_manager(kwargs)
489+
490+
cache_key = flask.request.args.get("cacheKey")
491+
job_id = flask.request.args.get("job")
492+
493+
output_value = await callback_manager.async_get_result(cache_key, job_id)
494+
495+
progress_long_callback(response, callback_manager)
496+
497+
return handle_rest_long_callback(
498+
output_value, callback_manager, response, error_handler, callback_ctx
499+
)
500+
501+
def handle_rest_long_callback(
502+
output_value,
503+
callback_manager,
504+
response,
505+
error_handler,
506+
callback_ctx,
507+
has_update=False,
508+
):
509+
cache_key = flask.request.args.get("cacheKey")
510+
job_id = flask.request.args.get("job")
465511
# Must get job_running after get_result since get_results terminates it.
466512
job_running = callback_manager.job_running(job_id)
467513
if not job_running and output_value is callback_manager.UNDEFINED:
@@ -501,8 +547,8 @@ def handle_long_callback(
501547
has_update = True
502548

503549
if output_value is callback_manager.UNDEFINED:
504-
return to_json(response), True, has_update
505-
return output_value, False, has_update
550+
return to_json(response), has_update, True
551+
return output_value, has_update, False
506552

507553
def prepare_response(
508554
output_value,
@@ -577,7 +623,6 @@ def prepare_response(
577623

578624
# pylint: disable=too-many-locals
579625
def wrap_func(func):
580-
581626
if long is not None:
582627
long_key = BaseLongCallbackManager.register_func(
583628
func,
@@ -604,16 +649,18 @@ def add_context(*args, **kwargs):
604649

605650
try:
606651
if long is not None:
607-
output_value, skip, has_update = handle_long_callback(
608-
kwargs,
609-
long,
610-
long_key,
611-
callback_ctx,
612-
response,
613-
error_handler,
614-
func,
615-
func_args,
616-
func_kwargs,
652+
if not flask.request.args.get("cacheKey"):
653+
return setup_long_callback(
654+
kwargs,
655+
long,
656+
long_key,
657+
func,
658+
func_args,
659+
func_kwargs,
660+
)
661+
662+
output_value, has_update, skip = update_long_callback(
663+
error_handler, callback_ctx, response, kwargs
617664
)
618665
if skip:
619666
return output_value
@@ -666,16 +713,17 @@ async def async_add_context(*args, **kwargs):
666713

667714
try:
668715
if long is not None:
669-
output_value, skip, has_update = handle_long_callback(
670-
kwargs,
671-
long,
672-
long_key,
673-
callback_ctx,
674-
response,
675-
error_handler,
676-
func,
677-
func_args,
678-
func_kwargs,
716+
if not flask.request.args.get("cacheKey"):
717+
return setup_long_callback(
718+
kwargs,
719+
long,
720+
long_key,
721+
func,
722+
func_args,
723+
func_kwargs,
724+
)
725+
output_value, has_update, skip = update_long_callback(
726+
error_handler, callback_ctx, response, kwargs
679727
)
680728
if skip:
681729
return output_value
@@ -695,7 +743,7 @@ async def async_add_context(*args, **kwargs):
695743

696744
prepare_response(
697745
output_value,
698-
has_output,
746+
output_spec,
699747
multi,
700748
response,
701749
callback_ctx,

dash/long_callback/managers/celery_manager.py

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
from dash.exceptions import PreventUpdate
1010
from dash.long_callback._proxy_set_props import ProxySetProps
1111
from dash.long_callback.managers import BaseLongCallbackManager
12+
import asyncio
13+
from functools import partial
1214

1315

1416
class CeleryManager(BaseLongCallbackManager):
@@ -90,7 +92,14 @@ def clear_cache_entry(self, key):
9092
self.handle.backend.delete(key)
9193

9294
def call_job_fn(self, key, job_fn, args, context):
93-
task = job_fn.delay(key, self._make_progress_key(key), args, context)
95+
if asyncio.iscoroutinefunction(job_fn):
96+
# pylint: disable-next=import-outside-toplevel,no-name-in-module,import-error
97+
from asgiref.sync import async_to_sync
98+
99+
new_job_fun = async_to_sync(job_fn)
100+
task = new_job_fun.delay(key, self._make_progress_key(key), args, context)
101+
else:
102+
task = job_fn.delay(key, self._make_progress_key(key), args, context)
94103
return task.task_id
95104

96105
def get_progress(self, key):
@@ -197,7 +206,49 @@ def run():
197206
result_key, json.dumps(user_callback_output, cls=PlotlyJSONEncoder)
198207
)
199208

200-
ctx.run(run)
209+
async def async_run():
210+
c = AttributeDict(**context)
211+
c.ignore_register_page = False
212+
c.updated_props = ProxySetProps(_set_props)
213+
context_value.set(c)
214+
errored = False
215+
try:
216+
if isinstance(user_callback_args, dict):
217+
user_callback_output = await fn(
218+
*maybe_progress, **user_callback_args
219+
)
220+
elif isinstance(user_callback_args, (list, tuple)):
221+
user_callback_output = await fn(
222+
*maybe_progress, *user_callback_args
223+
)
224+
else:
225+
user_callback_output = await fn(*maybe_progress, user_callback_args)
226+
except PreventUpdate:
227+
errored = True
228+
cache.set(result_key, {"_dash_no_update": "_dash_no_update"})
229+
except Exception as err: # pylint: disable=broad-except
230+
errored = True
231+
cache.set(
232+
result_key,
233+
{
234+
"long_callback_error": {
235+
"msg": str(err),
236+
"tb": traceback.format_exc(),
237+
}
238+
},
239+
)
240+
if asyncio.iscoroutine(user_callback_output):
241+
user_callback_output = await user_callback_output
242+
if not errored:
243+
cache.set(
244+
result_key, json.dumps(user_callback_output, cls=PlotlyJSONEncoder)
245+
)
246+
247+
if asyncio.iscoroutinefunction(fn):
248+
func = partial(ctx.run, async_run)
249+
asyncio.run(func())
250+
else:
251+
ctx.run(run)
201252

202253
return job_fn
203254

0 commit comments

Comments
 (0)