Skip to content

Commit aff0b8e

Browse files
GitHK and Andrei Neagu
authored
🐛Fixing nodeports_v2 port.set() race condition (ITISFoundation#2587)
* adding regression test * simplified and fixed regression test * added a better method to save the status of the nodeports * updated docstring * adding API to save values in batch * adding tests * fixing unittests * refactor to use different interface * ports is now hashable. refactoring interface * fixed interface and tests * added int support * removed hashing * port has to be hashable * not required to update ports before writing * removed hashable, required by tests * updating values to ports in parallel * current refactor * refactor and added tests Co-authored-by: Andrei Neagu <[email protected]>
1 parent bac03b9 commit aff0b8e

File tree

4 files changed

+145
-2
lines changed

4 files changed

+145
-2
lines changed

packages/simcore-sdk/src/simcore_sdk/node_ports_v2/nodeports_v2.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import logging
2+
from collections import deque
23
from pathlib import Path
3-
from typing import Any, Callable, Coroutine, Type
4+
from typing import Any, Callable, Coroutine, Dict, Type
45

56
from pydantic import BaseModel, Field
7+
from servicelib.utils import logged_gather
68

79
from ..node_ports_common.dbmanager import DBManager
810
from ..node_ports_common.exceptions import PortNotFound, UnboundPortError
@@ -99,3 +101,22 @@ async def _auto_update_from_db(self) -> None:
99101
self.internal_inputs[input_key]._node_ports = self
100102
for output_key in self.internal_outputs:
101103
self.internal_outputs[output_key]._node_ports = self
104+
105+
async def set_multiple(self, port_values: Dict[str, ItemConcreteValue]) -> None:
    """
    Sets the provided values to the respective input or output ports,
    then saves the node ports to the database a single time.

    Only supports port_key by name, not able to distinguish between inputs
    and outputs using the index. Each key is looked up in the outputs first
    and, if that raises ``UnboundPortError``, in the inputs.

    :param port_values: maps a port key to the value to set on that port
    :raises: whatever the inputs lookup raises when a key is found in
        neither outputs nor inputs; no value is set in that case
    """
    # pylint: disable=protected-access

    # Resolve every port BEFORE creating any coroutine: if a later key cannot
    # be resolved, no already-created coroutine is left behind un-awaited
    # (which would emit "coroutine ... was never awaited" warnings).
    ports = []
    for port_key in port_values:
        try:
            ports.append(self.internal_outputs[port_key])
        except UnboundPortError:
            # not available try inputs
            # if this fails it will raise another exception
            ports.append(self.internal_inputs[port_key])

    # `ports` was built in the iteration order of `port_values`, so zipping
    # with its values pairs every port with its own value.
    await logged_gather(
        *(port._set(value) for port, value in zip(ports, port_values.values()))
    )
    # persist once, after all values were applied
    await self.save_to_db_cb(self)

packages/simcore-sdk/src/simcore_sdk/node_ports_v2/port.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ async def get(self) -> Optional[ItemConcreteValue]:
106106

107107
return self._py_value_converter(value)
108108

109-
async def set(self, new_value: ItemConcreteValue) -> None:
109+
async def _set(self, new_value: ItemConcreteValue) -> None:
110110
log.debug(
111111
"setting %s[%s] with value %s", self.key, self.property_type, new_value
112112
)
@@ -129,4 +129,8 @@ async def set(self, new_value: ItemConcreteValue) -> None:
129129

130130
self.value = final_value
131131
self._used_default_value = False
132+
133+
async def set(self, new_value: ItemConcreteValue) -> None:
    """Set ``new_value`` on this port, then store the owning node ports.

    The value is applied through ``_set`` and afterwards the node ports'
    ``save_to_db_cb`` callback is awaited with the node ports themselves.
    """
    await self._set(new_value)
    owner = self._node_ports
    await owner.save_to_db_cb(owner)

packages/simcore-sdk/tests/integration/test_node_ports_v2_nodeports2.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import filecmp
88
import tempfile
99
import threading
10+
from asyncio import gather
1011
from pathlib import Path
1112
from typing import Any, Callable, Dict, Type, Union
1213
from uuid import uuid4
@@ -15,6 +16,7 @@
1516
import pytest
1617
import sqlalchemy as sa
1718
from simcore_sdk import node_ports_v2
19+
from simcore_sdk.node_ports_common.exceptions import UnboundPortError
1820
from simcore_sdk.node_ports_v2 import exceptions
1921
from simcore_sdk.node_ports_v2.links import ItemConcreteValue
2022
from simcore_sdk.node_ports_v2.nodeports_v2 import Nodeports
@@ -554,3 +556,111 @@ async def test_file_mapping(
554556
assert received_file_link["path"] == file_id
555557
# received a new eTag
556558
assert received_file_link["eTag"]
559+
560+
561+
@pytest.fixture
def int_item_value() -> int:
    """Integer value used for the serial port writes in the regression test."""
    return 42
564+
565+
566+
@pytest.fixture
def parallel_int_item_value() -> int:
    """Integer value used for the concurrent port writes in the regression test."""
    return 142
569+
570+
571+
@pytest.fixture
def port_count() -> int:
    """Number of ports created by the tests.

    The issue manifests from 4 ports onwards; many more ports are used to be
    sure the issue always occurs in CI or locally.
    """
    return 20
577+
578+
579+
async def test_regression_concurrent_port_update_fails(
    user_id: int,
    project_id: str,
    node_uuid: str,
    special_configuration: Callable,
    int_item_value: int,
    parallel_int_item_value: int,
    port_count: int,
) -> None:
    """
    when using `await PORTS.outputs` test will fail
    an unexpected status will end up in the database

    Documents the race condition of calling `port.set()` concurrently:
    after the parallel writes at least one port is expected to still hold
    the value written during the serial phase.
    """

    outputs = [(f"value_{i}", "integer", None) for i in range(port_count)]
    config_dict, _, _ = special_configuration(inputs=[], outputs=outputs)

    PORTS = await node_ports_v2.ports(
        user_id=user_id, project_id=project_id, node_uuid=node_uuid
    )
    await check_config_valid(PORTS, config_dict)

    # when writing in serial these are expected to work
    for item_key, _, _ in outputs:
        await (await PORTS.outputs)[item_key].set(int_item_value)
        assert (await PORTS.outputs)[item_key].value == int_item_value

    # when writing in parallel and reading back,
    # they fail, with enough concurrency
    async def _upload_task(item_key: str) -> None:
        await (await PORTS.outputs)[item_key].set(parallel_int_item_value)

    # updating in parallel creates a race condition
    results = await gather(*[_upload_task(item_key) for item_key, _, _ in outputs])
    assert len(results) == port_count

    # since a race condition was created when uploading values in parallel
    # it is expected to find at least one mismatching value here.
    # The values are inspected directly instead of matching the text of a
    # rewritten AssertionError, whose format depends on the pytest version
    # and verbosity.
    stale_values = []
    for item_key, _, _ in outputs:
        stored_value = (await PORTS.outputs)[item_key].value
        if stored_value != parallel_int_item_value:
            stale_values.append(stored_value)

    assert stale_values, "expected at least one port to miss the parallel update"
    # every port that missed the update must still hold the serial value
    assert all(value == int_item_value for value in stale_values)
624+
625+
626+
async def test_batch_update_inputs_outputs(
    user_id: int,
    project_id: str,
    node_uuid: str,
    special_configuration: Callable,
    port_count: int,
) -> None:
    """
    Sets all outputs and then all inputs through `set_multiple` and checks
    that each port reads back the value assigned to it; finally checks that
    a key present in neither inputs nor outputs raises `UnboundPortError`.
    """
    inputs_offset = 1000

    outputs = [(f"value_out_{i}", "integer", None) for i in range(port_count)]
    inputs = [(f"value_in_{i}", "integer", None) for i in range(port_count)]
    config_dict, _, _ = special_configuration(inputs=inputs, outputs=outputs)

    PORTS = await node_ports_v2.ports(
        user_id=user_id, project_id=project_id, node_uuid=node_uuid
    )
    await check_config_valid(PORTS, config_dict)

    # outputs receive 0..port_count-1 in iteration order
    output_values = {}
    for index, port in enumerate((await PORTS.outputs).values()):
        output_values[port.key] = index
    await PORTS.set_multiple(output_values)

    # inputs receive 1000.. in iteration order
    input_values = {}
    for index, port in enumerate((await PORTS.inputs).values(), start=inputs_offset):
        input_values[port.key] = index
    await PORTS.set_multiple(input_values)

    ports_outputs = await PORTS.outputs
    ports_inputs = await PORTS.inputs

    for expected, (item_key, _, _) in enumerate(outputs):
        output_port = ports_outputs[item_key]
        assert output_port.value == expected
        assert await output_port.get() == expected

    for expected, (item_key, _, _) in enumerate(inputs, start=inputs_offset):
        input_port = ports_inputs[item_key]
        assert input_port.value == expected
        assert await input_port.get() == expected

    # test missing key raises error
    with pytest.raises(UnboundPortError):
        await PORTS.set_multiple({"missing_key_in_both": 123132})

packages/simcore-sdk/tests/unit/test_node_ports_v2_nodeports_v2.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,14 @@ async def mock_node_port_creator_cb(*args, **kwargs):
125125
assert await node_ports.get(port.key) == port.value
126126
await node_ports.set(port.key, port.value)
127127

128+
# test batch add
129+
await node_ports.set_multiple(
130+
{
131+
port.key: port.value
132+
for port in list(original_inputs.values()) + list(original_outputs.values())
133+
}
134+
)
135+
128136

129137
@pytest.fixture(scope="session")
130138
def e_tag() -> str:

0 commit comments

Comments
 (0)