Skip to content

Commit e24e094

Browse files
committed
adding new use_async attribute to Dash and having callbacks and layouts reference this in order to determine whether or not they should load as async functions.
1 parent e1002d5 commit e24e094

File tree

2 files changed

+488
-63
lines changed

2 files changed

+488
-63
lines changed

dash/_callback.py

Lines changed: 233 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,14 +40,17 @@
4040
from ._callback_context import context_value
4141

4242

43-
async def _invoke_callback(func, *args, **kwargs):
43+
async def _async_invoke_callback(func, *args, **kwargs):
4444
# Check if the function is a coroutine function
4545
if asyncio.iscoroutinefunction(func):
4646
return await func(*args, **kwargs)
4747
else:
4848
# If the function is not a coroutine, call it directly
4949
return func(*args, **kwargs)
5050

51+
def _invoke_callback(func, *args, **kwargs):
52+
return func(*args, **kwargs)
53+
5154

5255
class NoUpdate:
5356
def to_plotly_json(self): # pylint: disable=no-self-use
@@ -321,6 +324,7 @@ def register_callback(
321324
manager = _kwargs.get("manager")
322325
running = _kwargs.get("running")
323326
on_error = _kwargs.get("on_error")
327+
use_async = _kwargs.get("app_use_async")
324328
if running is not None:
325329
if not isinstance(running[0], (list, tuple)):
326330
running = [running]
@@ -359,7 +363,229 @@ def wrap_func(func):
359363
)
360364

361365
@wraps(func)
362-
async def add_context(*args, **kwargs):
366+
async def async_add_context(*args, **kwargs):
367+
output_spec = kwargs.pop("outputs_list")
368+
app_callback_manager = kwargs.pop("long_callback_manager", None)
369+
370+
callback_ctx = kwargs.pop(
371+
"callback_context", AttributeDict({"updated_props": {}})
372+
)
373+
app = kwargs.pop("app", None)
374+
callback_manager = long and long.get("manager", app_callback_manager)
375+
error_handler = on_error or kwargs.pop("app_on_error", None)
376+
original_packages = set(ComponentRegistry.registry)
377+
378+
if has_output:
379+
_validate.validate_output_spec(insert_output, output_spec, Output)
380+
381+
context_value.set(callback_ctx)
382+
383+
func_args, func_kwargs = _validate.validate_and_group_input_args(
384+
args, inputs_state_indices
385+
)
386+
387+
response: dict = {"multi": True}
388+
has_update = False
389+
390+
if long is not None:
391+
if not callback_manager:
392+
raise MissingLongCallbackManagerError(
393+
"Running `long` callbacks requires a manager to be installed.\n"
394+
"Available managers:\n"
395+
"- Diskcache (`pip install dash[diskcache]`) to run callbacks in a separate Process"
396+
" and store results on the local filesystem.\n"
397+
"- Celery (`pip install dash[celery]`) to run callbacks in a celery worker"
398+
" and store results on redis.\n"
399+
)
400+
401+
progress_outputs = long.get("progress")
402+
cache_key = flask.request.args.get("cacheKey")
403+
job_id = flask.request.args.get("job")
404+
old_job = flask.request.args.getlist("oldJob")
405+
406+
current_key = callback_manager.build_cache_key(
407+
func,
408+
# Inputs provided as dict is kwargs.
409+
func_args if func_args else func_kwargs,
410+
long.get("cache_args_to_ignore", []),
411+
)
412+
413+
if old_job:
414+
for job in old_job:
415+
callback_manager.terminate_job(job)
416+
417+
if not cache_key:
418+
cache_key = current_key
419+
420+
job_fn = callback_manager.func_registry.get(long_key)
421+
422+
ctx_value = AttributeDict(**context_value.get())
423+
ctx_value.ignore_register_page = True
424+
ctx_value.pop("background_callback_manager")
425+
ctx_value.pop("dash_response")
426+
427+
job = callback_manager.call_job_fn(
428+
cache_key,
429+
job_fn,
430+
func_args if func_args else func_kwargs,
431+
ctx_value,
432+
)
433+
434+
data = {
435+
"cacheKey": cache_key,
436+
"job": job,
437+
}
438+
439+
cancel = long.get("cancel")
440+
if cancel:
441+
data["cancel"] = cancel
442+
443+
progress_default = long.get("progressDefault")
444+
if progress_default:
445+
data["progressDefault"] = {
446+
str(o): x
447+
for o, x in zip(progress_outputs, progress_default)
448+
}
449+
return to_json(data)
450+
if progress_outputs:
451+
# Get the progress before the result as it would be erased after the results.
452+
progress = callback_manager.get_progress(cache_key)
453+
if progress:
454+
response["progress"] = {
455+
str(x): progress[i] for i, x in enumerate(progress_outputs)
456+
}
457+
458+
output_value = callback_manager.get_result(cache_key, job_id)
459+
# Must get job_running after get_result since get_results terminates it.
460+
job_running = callback_manager.job_running(job_id)
461+
if not job_running and output_value is callback_manager.UNDEFINED:
462+
# Job canceled -> no output to close the loop.
463+
output_value = NoUpdate()
464+
465+
elif (
466+
isinstance(output_value, dict)
467+
and "long_callback_error" in output_value
468+
):
469+
error = output_value.get("long_callback_error", {})
470+
exc = LongCallbackError(
471+
f"An error occurred inside a long callback: {error['msg']}\n{error['tb']}"
472+
)
473+
if error_handler:
474+
output_value = error_handler(exc)
475+
476+
if output_value is None:
477+
output_value = NoUpdate()
478+
# set_props from the error handler uses the original ctx
479+
# instead of manager.get_updated_props since it runs in the
480+
# request process.
481+
has_update = (
482+
_set_side_update(callback_ctx, response)
483+
or output_value is not None
484+
)
485+
else:
486+
raise exc
487+
488+
if job_running and output_value is not callback_manager.UNDEFINED:
489+
# cached results.
490+
callback_manager.terminate_job(job_id)
491+
492+
if multi and isinstance(output_value, (list, tuple)):
493+
output_value = [
494+
NoUpdate() if NoUpdate.is_no_update(r) else r
495+
for r in output_value
496+
]
497+
updated_props = callback_manager.get_updated_props(cache_key)
498+
if len(updated_props) > 0:
499+
response["sideUpdate"] = updated_props
500+
has_update = True
501+
502+
if output_value is callback_manager.UNDEFINED:
503+
return to_json(response)
504+
else:
505+
try:
506+
output_value = await _async_invoke_callback(func, *func_args, **func_kwargs)
507+
except PreventUpdate as err:
508+
raise err
509+
except Exception as err: # pylint: disable=broad-exception-caught
510+
if error_handler:
511+
output_value = error_handler(err)
512+
513+
# If the error returns nothing, automatically puts NoUpdate for response.
514+
if output_value is None and has_output:
515+
output_value = NoUpdate()
516+
else:
517+
raise err
518+
519+
component_ids = collections.defaultdict(dict)
520+
521+
if has_output:
522+
if not multi:
523+
output_value, output_spec = [output_value], [output_spec]
524+
flat_output_values = output_value
525+
else:
526+
if isinstance(output_value, (list, tuple)):
527+
# For multi-output, allow top-level collection to be
528+
# list or tuple
529+
output_value = list(output_value)
530+
531+
if NoUpdate.is_no_update(output_value):
532+
flat_output_values = [output_value]
533+
else:
534+
# Flatten grouping and validate grouping structure
535+
flat_output_values = flatten_grouping(output_value, output)
536+
537+
if not NoUpdate.is_no_update(output_value):
538+
_validate.validate_multi_return(
539+
output_spec, flat_output_values, callback_id
540+
)
541+
542+
for val, spec in zip(flat_output_values, output_spec):
543+
if NoUpdate.is_no_update(val):
544+
continue
545+
for vali, speci in (
546+
zip(val, spec) if isinstance(spec, list) else [[val, spec]]
547+
):
548+
if not NoUpdate.is_no_update(vali):
549+
has_update = True
550+
id_str = stringify_id(speci["id"])
551+
prop = clean_property_name(speci["property"])
552+
component_ids[id_str][prop] = vali
553+
else:
554+
if output_value is not None:
555+
raise InvalidCallbackReturnValue(
556+
f"No-output callback received return value: {output_value}"
557+
)
558+
output_value = []
559+
flat_output_values = []
560+
561+
if not long:
562+
has_update = _set_side_update(callback_ctx, response) or has_update
563+
564+
if not has_update:
565+
raise PreventUpdate
566+
567+
response["response"] = component_ids
568+
569+
if len(ComponentRegistry.registry) != len(original_packages):
570+
diff_packages = list(
571+
set(ComponentRegistry.registry).difference(original_packages)
572+
)
573+
if not allow_dynamic_callbacks:
574+
raise ImportedInsideCallbackError(
575+
f"Component librar{'y' if len(diff_packages) == 1 else 'ies'} was imported during callback.\n"
576+
"You can set `_allow_dynamic_callbacks` to allow for development purpose only."
577+
)
578+
dist = app.get_dist(diff_packages)
579+
response["dist"] = dist
580+
581+
try:
582+
jsonResponse = to_json(response)
583+
except TypeError:
584+
_validate.fail_callback_output(output_value, output)
585+
586+
return jsonResponse
587+
588+
def add_context(*args, **kwargs):
363589
output_spec = kwargs.pop("outputs_list")
364590
app_callback_manager = kwargs.pop("long_callback_manager", None)
365591

@@ -499,7 +725,7 @@ async def add_context(*args, **kwargs):
499725
return to_json(response)
500726
else:
501727
try:
502-
output_value = await _invoke_callback(func, *func_args, **func_kwargs)
728+
output_value = _invoke_callback(func, *func_args, **func_kwargs)
503729
except PreventUpdate as err:
504730
raise err
505731
except Exception as err: # pylint: disable=broad-exception-caught
@@ -581,7 +807,10 @@ async def add_context(*args, **kwargs):
581807

582808
return jsonResponse
583809

584-
callback_map[callback_id]["callback"] = add_context
810+
if use_async:
811+
callback_map[callback_id]["callback"] = async_add_context
812+
else:
813+
callback_map[callback_id]["callback"] = add_context
585814

586815
return func
587816

0 commit comments

Comments
 (0)