|
40 | 40 | from ._callback_context import context_value
|
41 | 41 |
|
42 | 42 |
|
43 |
| -async def _invoke_callback(func, *args, **kwargs): |
| 43 | +async def _async_invoke_callback(func, *args, **kwargs): |
44 | 44 | # Check if the function is a coroutine function
|
45 | 45 | if asyncio.iscoroutinefunction(func):
|
46 | 46 | return await func(*args, **kwargs)
|
47 | 47 | else:
|
48 | 48 | # If the function is not a coroutine, call it directly
|
49 | 49 | return func(*args, **kwargs)
|
50 | 50 |
|
| 51 | +def _invoke_callback(func, *args, **kwargs): |
| 52 | + return func(*args, **kwargs) |
| 53 | + |
51 | 54 |
|
52 | 55 | class NoUpdate:
|
53 | 56 | def to_plotly_json(self): # pylint: disable=no-self-use
|
@@ -321,6 +324,7 @@ def register_callback(
|
321 | 324 | manager = _kwargs.get("manager")
|
322 | 325 | running = _kwargs.get("running")
|
323 | 326 | on_error = _kwargs.get("on_error")
|
| 327 | + use_async = _kwargs.get("app_use_async") |
324 | 328 | if running is not None:
|
325 | 329 | if not isinstance(running[0], (list, tuple)):
|
326 | 330 | running = [running]
|
@@ -359,7 +363,229 @@ def wrap_func(func):
|
359 | 363 | )
|
360 | 364 |
|
361 | 365 | @wraps(func)
|
362 |
| - async def add_context(*args, **kwargs): |
| 366 | + async def async_add_context(*args, **kwargs): |
| 367 | + output_spec = kwargs.pop("outputs_list") |
| 368 | + app_callback_manager = kwargs.pop("long_callback_manager", None) |
| 369 | + |
| 370 | + callback_ctx = kwargs.pop( |
| 371 | + "callback_context", AttributeDict({"updated_props": {}}) |
| 372 | + ) |
| 373 | + app = kwargs.pop("app", None) |
| 374 | + callback_manager = long and long.get("manager", app_callback_manager) |
| 375 | + error_handler = on_error or kwargs.pop("app_on_error", None) |
| 376 | + original_packages = set(ComponentRegistry.registry) |
| 377 | + |
| 378 | + if has_output: |
| 379 | + _validate.validate_output_spec(insert_output, output_spec, Output) |
| 380 | + |
| 381 | + context_value.set(callback_ctx) |
| 382 | + |
| 383 | + func_args, func_kwargs = _validate.validate_and_group_input_args( |
| 384 | + args, inputs_state_indices |
| 385 | + ) |
| 386 | + |
| 387 | + response: dict = {"multi": True} |
| 388 | + has_update = False |
| 389 | + |
| 390 | + if long is not None: |
| 391 | + if not callback_manager: |
| 392 | + raise MissingLongCallbackManagerError( |
| 393 | + "Running `long` callbacks requires a manager to be installed.\n" |
| 394 | + "Available managers:\n" |
| 395 | + "- Diskcache (`pip install dash[diskcache]`) to run callbacks in a separate Process" |
| 396 | + " and store results on the local filesystem.\n" |
| 397 | + "- Celery (`pip install dash[celery]`) to run callbacks in a celery worker" |
| 398 | + " and store results on redis.\n" |
| 399 | + ) |
| 400 | + |
| 401 | + progress_outputs = long.get("progress") |
| 402 | + cache_key = flask.request.args.get("cacheKey") |
| 403 | + job_id = flask.request.args.get("job") |
| 404 | + old_job = flask.request.args.getlist("oldJob") |
| 405 | + |
| 406 | + current_key = callback_manager.build_cache_key( |
| 407 | + func, |
| 408 | + # Inputs provided as dict is kwargs. |
| 409 | + func_args if func_args else func_kwargs, |
| 410 | + long.get("cache_args_to_ignore", []), |
| 411 | + ) |
| 412 | + |
| 413 | + if old_job: |
| 414 | + for job in old_job: |
| 415 | + callback_manager.terminate_job(job) |
| 416 | + |
| 417 | + if not cache_key: |
| 418 | + cache_key = current_key |
| 419 | + |
| 420 | + job_fn = callback_manager.func_registry.get(long_key) |
| 421 | + |
| 422 | + ctx_value = AttributeDict(**context_value.get()) |
| 423 | + ctx_value.ignore_register_page = True |
| 424 | + ctx_value.pop("background_callback_manager") |
| 425 | + ctx_value.pop("dash_response") |
| 426 | + |
| 427 | + job = callback_manager.call_job_fn( |
| 428 | + cache_key, |
| 429 | + job_fn, |
| 430 | + func_args if func_args else func_kwargs, |
| 431 | + ctx_value, |
| 432 | + ) |
| 433 | + |
| 434 | + data = { |
| 435 | + "cacheKey": cache_key, |
| 436 | + "job": job, |
| 437 | + } |
| 438 | + |
| 439 | + cancel = long.get("cancel") |
| 440 | + if cancel: |
| 441 | + data["cancel"] = cancel |
| 442 | + |
| 443 | + progress_default = long.get("progressDefault") |
| 444 | + if progress_default: |
| 445 | + data["progressDefault"] = { |
| 446 | + str(o): x |
| 447 | + for o, x in zip(progress_outputs, progress_default) |
| 448 | + } |
| 449 | + return to_json(data) |
| 450 | + if progress_outputs: |
| 451 | + # Get the progress before the result as it would be erased after the results. |
| 452 | + progress = callback_manager.get_progress(cache_key) |
| 453 | + if progress: |
| 454 | + response["progress"] = { |
| 455 | + str(x): progress[i] for i, x in enumerate(progress_outputs) |
| 456 | + } |
| 457 | + |
| 458 | + output_value = callback_manager.get_result(cache_key, job_id) |
| 459 | + # Must get job_running after get_result since get_results terminates it. |
| 460 | + job_running = callback_manager.job_running(job_id) |
| 461 | + if not job_running and output_value is callback_manager.UNDEFINED: |
| 462 | + # Job canceled -> no output to close the loop. |
| 463 | + output_value = NoUpdate() |
| 464 | + |
| 465 | + elif ( |
| 466 | + isinstance(output_value, dict) |
| 467 | + and "long_callback_error" in output_value |
| 468 | + ): |
| 469 | + error = output_value.get("long_callback_error", {}) |
| 470 | + exc = LongCallbackError( |
| 471 | + f"An error occurred inside a long callback: {error['msg']}\n{error['tb']}" |
| 472 | + ) |
| 473 | + if error_handler: |
| 474 | + output_value = error_handler(exc) |
| 475 | + |
| 476 | + if output_value is None: |
| 477 | + output_value = NoUpdate() |
| 478 | + # set_props from the error handler uses the original ctx |
| 479 | + # instead of manager.get_updated_props since it runs in the |
| 480 | + # request process. |
| 481 | + has_update = ( |
| 482 | + _set_side_update(callback_ctx, response) |
| 483 | + or output_value is not None |
| 484 | + ) |
| 485 | + else: |
| 486 | + raise exc |
| 487 | + |
| 488 | + if job_running and output_value is not callback_manager.UNDEFINED: |
| 489 | + # cached results. |
| 490 | + callback_manager.terminate_job(job_id) |
| 491 | + |
| 492 | + if multi and isinstance(output_value, (list, tuple)): |
| 493 | + output_value = [ |
| 494 | + NoUpdate() if NoUpdate.is_no_update(r) else r |
| 495 | + for r in output_value |
| 496 | + ] |
| 497 | + updated_props = callback_manager.get_updated_props(cache_key) |
| 498 | + if len(updated_props) > 0: |
| 499 | + response["sideUpdate"] = updated_props |
| 500 | + has_update = True |
| 501 | + |
| 502 | + if output_value is callback_manager.UNDEFINED: |
| 503 | + return to_json(response) |
| 504 | + else: |
| 505 | + try: |
| 506 | + output_value = await _async_invoke_callback(func, *func_args, **func_kwargs) |
| 507 | + except PreventUpdate as err: |
| 508 | + raise err |
| 509 | + except Exception as err: # pylint: disable=broad-exception-caught |
| 510 | + if error_handler: |
| 511 | + output_value = error_handler(err) |
| 512 | + |
| 513 | + # If the error returns nothing, automatically puts NoUpdate for response. |
| 514 | + if output_value is None and has_output: |
| 515 | + output_value = NoUpdate() |
| 516 | + else: |
| 517 | + raise err |
| 518 | + |
| 519 | + component_ids = collections.defaultdict(dict) |
| 520 | + |
| 521 | + if has_output: |
| 522 | + if not multi: |
| 523 | + output_value, output_spec = [output_value], [output_spec] |
| 524 | + flat_output_values = output_value |
| 525 | + else: |
| 526 | + if isinstance(output_value, (list, tuple)): |
| 527 | + # For multi-output, allow top-level collection to be |
| 528 | + # list or tuple |
| 529 | + output_value = list(output_value) |
| 530 | + |
| 531 | + if NoUpdate.is_no_update(output_value): |
| 532 | + flat_output_values = [output_value] |
| 533 | + else: |
| 534 | + # Flatten grouping and validate grouping structure |
| 535 | + flat_output_values = flatten_grouping(output_value, output) |
| 536 | + |
| 537 | + if not NoUpdate.is_no_update(output_value): |
| 538 | + _validate.validate_multi_return( |
| 539 | + output_spec, flat_output_values, callback_id |
| 540 | + ) |
| 541 | + |
| 542 | + for val, spec in zip(flat_output_values, output_spec): |
| 543 | + if NoUpdate.is_no_update(val): |
| 544 | + continue |
| 545 | + for vali, speci in ( |
| 546 | + zip(val, spec) if isinstance(spec, list) else [[val, spec]] |
| 547 | + ): |
| 548 | + if not NoUpdate.is_no_update(vali): |
| 549 | + has_update = True |
| 550 | + id_str = stringify_id(speci["id"]) |
| 551 | + prop = clean_property_name(speci["property"]) |
| 552 | + component_ids[id_str][prop] = vali |
| 553 | + else: |
| 554 | + if output_value is not None: |
| 555 | + raise InvalidCallbackReturnValue( |
| 556 | + f"No-output callback received return value: {output_value}" |
| 557 | + ) |
| 558 | + output_value = [] |
| 559 | + flat_output_values = [] |
| 560 | + |
| 561 | + if not long: |
| 562 | + has_update = _set_side_update(callback_ctx, response) or has_update |
| 563 | + |
| 564 | + if not has_update: |
| 565 | + raise PreventUpdate |
| 566 | + |
| 567 | + response["response"] = component_ids |
| 568 | + |
| 569 | + if len(ComponentRegistry.registry) != len(original_packages): |
| 570 | + diff_packages = list( |
| 571 | + set(ComponentRegistry.registry).difference(original_packages) |
| 572 | + ) |
| 573 | + if not allow_dynamic_callbacks: |
| 574 | + raise ImportedInsideCallbackError( |
| 575 | + f"Component librar{'y' if len(diff_packages) == 1 else 'ies'} was imported during callback.\n" |
| 576 | + "You can set `_allow_dynamic_callbacks` to allow for development purpose only." |
| 577 | + ) |
| 578 | + dist = app.get_dist(diff_packages) |
| 579 | + response["dist"] = dist |
| 580 | + |
| 581 | + try: |
| 582 | + jsonResponse = to_json(response) |
| 583 | + except TypeError: |
| 584 | + _validate.fail_callback_output(output_value, output) |
| 585 | + |
| 586 | + return jsonResponse |
| 587 | + |
| 588 | + def add_context(*args, **kwargs): |
363 | 589 | output_spec = kwargs.pop("outputs_list")
|
364 | 590 | app_callback_manager = kwargs.pop("long_callback_manager", None)
|
365 | 591 |
|
@@ -499,7 +725,7 @@ async def add_context(*args, **kwargs):
|
499 | 725 | return to_json(response)
|
500 | 726 | else:
|
501 | 727 | try:
|
502 |
| - output_value = await _invoke_callback(func, *func_args, **func_kwargs) |
| 728 | + output_value = _invoke_callback(func, *func_args, **func_kwargs) |
503 | 729 | except PreventUpdate as err:
|
504 | 730 | raise err
|
505 | 731 | except Exception as err: # pylint: disable=broad-exception-caught
|
@@ -581,7 +807,10 @@ async def add_context(*args, **kwargs):
|
581 | 807 |
|
582 | 808 | return jsonResponse
|
583 | 809 |
|
584 |
| - callback_map[callback_id]["callback"] = add_context |
| 810 | + if use_async: |
| 811 | + callback_map[callback_id]["callback"] = async_add_context |
| 812 | + else: |
| 813 | + callback_map[callback_id]["callback"] = add_context |
585 | 814 |
|
586 | 815 | return func
|
587 | 816 |
|
|
0 commit comments