From c422feb301c3185ee6af81be1a4179d97e47640d Mon Sep 17 00:00:00 2001 From: Leonard <92177433+LeonardHd@users.noreply.github.com> Date: Tue, 20 Aug 2024 20:21:40 +0000 Subject: [PATCH] fix: entity id should compare based on name and key --- .../models/utils/entity_utils.py | 12 ++++++++++ tests/utils/__init__.py | 1 + tests/utils/test_entity_utils.py | 24 +++++++++++++++++++ 3 files changed, 37 insertions(+) create mode 100644 tests/utils/__init__.py create mode 100644 tests/utils/test_entity_utils.py diff --git a/azure/durable_functions/models/utils/entity_utils.py b/azure/durable_functions/models/utils/entity_utils.py index f5669323..aebdb113 100644 --- a/azure/durable_functions/models/utils/entity_utils.py +++ b/azure/durable_functions/models/utils/entity_utils.py @@ -89,3 +89,15 @@ def __str__(self) -> str: A SchedulerId-based string representation of the EntityId """ return EntityId.get_scheduler_id(entity_id=self) + + def __eq__(self, other: object) -> bool: + """Check if two EntityId objects are equal. + + Parameters + ---------- + other: object + """ + if not isinstance(other, EntityId): + return False + + return self.name == other.name and self.key == other.key diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/tests/utils/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/utils/test_entity_utils.py b/tests/utils/test_entity_utils.py new file mode 100644 index 00000000..b47a6512 --- /dev/null +++ b/tests/utils/test_entity_utils.py @@ -0,0 +1,24 @@ +import pytest +from azure.durable_functions.models.utils.entity_utils import EntityId + +@pytest.mark.parametrize( + ("name_e1", "key_1", "name_e2", "key_2", "expected"), + [ + ("name1", "key1", "name1", "key1", True), + ("name1", "key1", "name1", "key2", False), + ("name1", "key1", "name2", "key1", False), + ("name1", "key1", "name2", "key2", False), + ], +) +def test_equal_entity_by_name_and_key(name_e1, key_1, name_e2, key_2, expected): + + entity1 = EntityId(name_e1, key_1) + entity2 = EntityId(name_e2, key_2) + + assert (entity1 == entity2) == expected + +def test_equality_with_non_entity_id(): + + entity = EntityId("name", "key") + + assert (entity == "not an entity id") == False