|
20 | 20 | from pathlib import Path
|
21 | 21 | import pytest
|
22 | 22 | from shutil import rmtree
|
| 23 | +from time import time |
23 | 24 | from typing import List, TYPE_CHECKING, Set, Tuple, Union
|
24 | 25 |
|
25 | 26 | from cylc.flow.config import WorkflowConfig
|
|
32 | 33 | install as cylc_install,
|
33 | 34 | get_option_parser as install_gop
|
34 | 35 | )
|
| 36 | +from cylc.flow.util import serialise |
35 | 37 | from cylc.flow.wallclock import get_current_time_string
|
36 | 38 | from cylc.flow.workflow_files import infer_latest_run_from_id
|
| 39 | +from cylc.flow.workflow_status import StopMode |
37 | 40 |
|
38 | 41 | from .utils import _rm_if_empty
|
39 | 42 | from .utils.flow_tools import (
|
@@ -473,3 +476,140 @@ def _inner(source, **kwargs):
|
473 | 476 | workflow_id = infer_latest_run_from_id(workflow_id)
|
474 | 477 | return workflow_id
|
475 | 478 | yield _inner
|
| 479 | + |
| 480 | + |
@pytest.fixture
def reflog():
    """Integration test version of the --reflog CLI option.

    This returns a set which captures task triggers.

    Note, you'll need to call this on the scheduler *after* you have started
    it.

    Args:
        schd:
            The scheduler to capture triggering information for.
        flow_nums:
            If True, the flow numbers of the task being triggered will be added
            to the end of each entry.

    Returns:
        tuple

        (task, triggers):
            If flow_nums == False
        (task, flow_nums, triggers):
            If flow_nums == True

        task:
            The [relative] task ID e.g. "1/a".
        flow_nums:
            The serialised flow nums e.g. ["1"].
        triggers:
            Sorted tuple of the trigger IDs, e.g. ("1/a", "2/b").

    """

    def _reflog(schd, flow_nums=False):
        # keep a handle on the real submission method so we can delegate to it
        wrapped_submit = schd.task_job_mgr.submit_task_jobs
        triggers = set()

        def _capture_submit(*args, **kwargs):
            # delegate to the real method, then record what was submitted
            itasks = wrapped_submit(*args, **kwargs)
            for itask in itasks:
                resolved = tuple(
                    sorted(itask.state.get_resolved_dependencies())
                )
                entry = (itask.identity,)
                if flow_nums:
                    entry += (serialise(itask.flow_nums),)
                # an empty dependency tuple is recorded as None
                entry += (resolved or None,)
                triggers.add(entry)
            return itasks

        # NOTE: this patch is not undone; the scheduler is assumed to be
        # torn down at the end of the test
        schd.task_job_mgr.submit_task_jobs = _capture_submit

        return triggers

    return _reflog
| 536 | + |
| 537 | + |
@pytest.fixture
def complete():
    """Wait for the workflow, or tasks within it to complete.

    Args:
        schd:
            The scheduler to await.
        tokens_list:
            If specified, this will wait for the tasks represented by these
            tokens to be marked as completed by the task pool.
        stop_mode:
            If tokens_list is not provided, this will wait for the scheduler
            to be shutdown with the specified mode (default = AUTO, i.e.
            workflow completed normally).
        timeout:
            Max time to wait for the condition to be met.

            Note, if you need to increase this, you might want to rethink your
            test.

            Note, use this timeout rather than wrapping the complete call with
            async_timeout (handles shutdown logic more cleanly).

    """
    async def _complete(
        schd,
        *tokens_list,
        stop_mode=StopMode.AUTO,
        timeout=60,
    ):
        start_time = time()
        # reduce the tokens to relative task tokens for comparison
        tokens_list = [tokens.task for tokens in tokens_list]

        # capture task completion
        remove_if_complete = schd.pool.remove_if_complete

        def _remove_if_complete(itask):
            ret = remove_if_complete(itask)
            if ret and itask.tokens.task in tokens_list:
                tokens_list.remove(itask.tokens.task)
            return ret

        # capture workflow shutdown
        set_stop = schd._set_stop
        has_shutdown = False

        def _set_stop(mode=None):
            nonlocal has_shutdown
            if mode == stop_mode:
                has_shutdown = True
                return set_stop(mode)
            else:
                # the workflow is stopping for a reason the test did not
                # expect: fail fast rather than hanging until the timeout
                set_stop(mode)
                raise Exception(f'Workflow bailed with stop mode = {mode}')

        # determine the completion condition
        if tokens_list:
            def done():
                # all awaited tasks have been removed as complete
                return not tokens_list
        else:
            def done():
                # the scheduler has been shut down with the expected mode
                return has_shutdown

        schd.pool.remove_if_complete = _remove_if_complete
        schd._set_stop = _set_stop
        try:
            # wait for the condition to be met
            while not done():
                # allow the main loop to advance
                await asyncio.sleep(0)
                if time() - start_time > timeout:
                    waiting_for = (
                        ', '.join(map(str, tokens_list))
                        or 'workflow shutdown'
                    )
                    raise Exception(f'Timeout waiting for {waiting_for}')
        finally:
            # restore the regular scheduler logic even if we bailed out,
            # so a failed wait doesn't leave the scheduler monkey-patched
            schd.pool.remove_if_complete = remove_if_complete
            schd._set_stop = set_stop

    return _complete
0 commit comments