|
7 | 7 | import filecmp |
8 | 8 | import tempfile |
9 | 9 | import threading |
| 10 | +from asyncio import gather |
10 | 11 | from pathlib import Path |
11 | 12 | from typing import Any, Callable, Dict, Type, Union |
12 | 13 | from uuid import uuid4 |
|
15 | 16 | import pytest |
16 | 17 | import sqlalchemy as sa |
17 | 18 | from simcore_sdk import node_ports_v2 |
| 19 | +from simcore_sdk.node_ports_common.exceptions import UnboundPortError |
18 | 20 | from simcore_sdk.node_ports_v2 import exceptions |
19 | 21 | from simcore_sdk.node_ports_v2.links import ItemConcreteValue |
20 | 22 | from simcore_sdk.node_ports_v2.nodeports_v2 import Nodeports |
@@ -554,3 +556,111 @@ async def test_file_mapping( |
554 | 556 | assert received_file_link["path"] == file_id |
555 | 557 | # received a new eTag |
556 | 558 | assert received_file_link["eTag"] |
| 559 | + |
| 560 | + |
| 561 | +@pytest.fixture |
| 562 | +def int_item_value() -> int: |
| 563 | + return 42 |
| 564 | + |
| 565 | + |
| 566 | +@pytest.fixture |
| 567 | +def parallel_int_item_value() -> int: |
| 568 | + return 142 |
| 569 | + |
| 570 | + |
| 571 | +@pytest.fixture |
| 572 | +def port_count() -> int: |
| 573 | + # the issue manifests from 4 ports onwards |
| 574 | + # going for many more ports to be sure issue |
| 575 | + # always occurs in CI or locally |
| 576 | + return 20 |
| 577 | + |
| 578 | + |
| 579 | +async def test_regression_concurrent_port_update_fails( |
| 580 | + user_id: int, |
| 581 | + project_id: str, |
| 582 | + node_uuid: str, |
| 583 | + special_configuration: Callable, |
| 584 | + int_item_value: int, |
| 585 | + parallel_int_item_value: int, |
| 586 | + port_count: int, |
| 587 | +) -> None: |
| 588 | + """ |
| 589 | + when using `await PORTS.outputs` test will fail |
| 590 | + an unexpected status will end up in the database |
| 591 | + """ |
| 592 | + |
| 593 | + outputs = [(f"value_{i}", "integer", None) for i in range(port_count)] |
| 594 | + config_dict, _, _ = special_configuration(inputs=[], outputs=outputs) |
| 595 | + |
| 596 | + PORTS = await node_ports_v2.ports( |
| 597 | + user_id=user_id, project_id=project_id, node_uuid=node_uuid |
| 598 | + ) |
| 599 | + await check_config_valid(PORTS, config_dict) |
| 600 | + |
| 601 | + # when writing in serial these are expected to work |
| 602 | + for item_key, _, _ in outputs: |
| 603 | + await (await PORTS.outputs)[item_key].set(int_item_value) |
| 604 | + assert (await PORTS.outputs)[item_key].value == int_item_value |
| 605 | + |
| 606 | + # when writing in parallel and reading back, |
| 607 | + # they fail, with enough concurrency |
| 608 | + async def _upload_task(item_key: str) -> None: |
| 609 | + await (await PORTS.outputs)[item_key].set(parallel_int_item_value) |
| 610 | + |
| 611 | + # updating in parallel creates a race condition |
| 612 | + results = await gather(*[_upload_task(item_key) for item_key, _, _ in outputs]) |
| 613 | + assert len(results) == port_count |
| 614 | + |
| 615 | + # since a race condition was created when uploading values in parallel |
| 616 | + # it is expected to find at least one mismatching value here |
| 617 | + with pytest.raises(AssertionError) as exc_info: |
| 618 | + for item_key, _, _ in outputs: |
| 619 | + assert (await PORTS.outputs)[item_key].value == parallel_int_item_value |
| 620 | + assert ( |
| 621 | + exc_info.value.args[0] |
| 622 | + == f"assert {int_item_value} == {parallel_int_item_value}\n +{int_item_value}\n -{parallel_int_item_value}" |
| 623 | + ) |
| 624 | + |
| 625 | + |
| 626 | +async def test_batch_update_inputs_outputs( |
| 627 | + user_id: int, |
| 628 | + project_id: str, |
| 629 | + node_uuid: str, |
| 630 | + special_configuration: Callable, |
| 631 | + port_count: int, |
| 632 | +) -> None: |
| 633 | + outputs = [(f"value_out_{i}", "integer", None) for i in range(port_count)] |
| 634 | + inputs = [(f"value_in_{i}", "integer", None) for i in range(port_count)] |
| 635 | + config_dict, _, _ = special_configuration(inputs=inputs, outputs=outputs) |
| 636 | + |
| 637 | + PORTS = await node_ports_v2.ports( |
| 638 | + user_id=user_id, project_id=project_id, node_uuid=node_uuid |
| 639 | + ) |
| 640 | + await check_config_valid(PORTS, config_dict) |
| 641 | + |
| 642 | + await PORTS.set_multiple( |
| 643 | + {port.key: k for k, port in enumerate((await PORTS.outputs).values())} |
| 644 | + ) |
| 645 | + await PORTS.set_multiple( |
| 646 | + { |
| 647 | + port.key: k |
| 648 | + for k, port in enumerate((await PORTS.inputs).values(), start=1000) |
| 649 | + } |
| 650 | + ) |
| 651 | + |
| 652 | + ports_outputs = await PORTS.outputs |
| 653 | + ports_inputs = await PORTS.inputs |
| 654 | + for k, asd in enumerate(outputs): |
| 655 | + item_key, _, _ = asd |
| 656 | + assert ports_outputs[item_key].value == k |
| 657 | + assert await ports_outputs[item_key].get() == k |
| 658 | + |
| 659 | + for k, asd in enumerate(inputs, start=1000): |
| 660 | + item_key, _, _ = asd |
| 661 | + assert ports_inputs[item_key].value == k |
| 662 | + assert await ports_inputs[item_key].get() == k |
| 663 | + |
| 664 | + # test missing key raises error |
| 665 | + with pytest.raises(UnboundPortError): |
| 666 | + await PORTS.set_multiple({"missing_key_in_both": 123132}) |
0 commit comments