Skip to content

Commit 9fc6121

Browse files
authored
fix: validate domain is present when calling set_provider on registry (#561)
* fix: Validate domain is present when calling `set_provider` on registry It is incorrect to call `ProviderRegister.set_provider` without a provider AND a domain. A validation check exists for the provider, but none for the domain. In this commit, we introduce that domain validation and introduce tests to capture this expected behavior. Signed-off-by: Matthew Keeler <[email protected]>
1 parent 7c27c7a commit 9fc6121

File tree

2 files changed

+109
-0
lines changed

2 files changed

+109
-0
lines changed

openfeature/provider/_registry.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ def __init__(self) -> None:
2424
def set_provider(self, domain: str, provider: FeatureProvider) -> None:
2525
if provider is None:
2626
raise GeneralError(error_message="No provider")
27+
if domain is None:
28+
raise GeneralError(error_message="No domain")
2729
providers = self._providers
2830
if domain in providers:
2931
old_provider = providers[domain]

tests/provider/test_registry.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
from unittest.mock import Mock
2+
3+
import pytest
4+
5+
from openfeature.exception import GeneralError
6+
from openfeature.provider._registry import ProviderRegistry
7+
from openfeature.provider.no_op_provider import NoOpProvider
8+
9+
10+
def test_registry_serves_noop_as_default():
11+
registry = ProviderRegistry()
12+
13+
assert isinstance(registry.get_default_provider(), NoOpProvider)
14+
assert isinstance(registry.get_provider("unknown domain"), NoOpProvider)
15+
16+
17+
def test_setting_provider_requires_domain():
18+
registry = ProviderRegistry()
19+
20+
with pytest.raises(GeneralError) as exc_info:
21+
registry.set_provider(None, NoOpProvider()) # type: ignore[reportArgumentType]
22+
23+
assert exc_info.value.error_message == "No domain"
24+
25+
26+
def test_setting_provider_requires_provider():
27+
registry = ProviderRegistry()
28+
29+
with pytest.raises(GeneralError) as exc_info:
30+
registry.set_provider("domain", None) # type: ignore[reportArgumentType]
31+
32+
assert exc_info.value.error_message == "No provider"
33+
34+
35+
def test_can_register_provider_to_multiple_domains():
36+
registry = ProviderRegistry()
37+
provider = NoOpProvider()
38+
39+
registry.set_provider("domain1", provider)
40+
registry.set_provider("domain2", provider)
41+
42+
assert registry.get_provider("domain1") is provider
43+
assert registry.get_provider("domain2") is provider
44+
45+
46+
def test_registering_provider_replaces_previous_provider():
47+
"""Test that registering a provider replaces the previous provider and calls shutdown on the old one."""
48+
49+
registry = ProviderRegistry()
50+
provider1 = Mock()
51+
provider2 = Mock()
52+
53+
registry.set_provider("domain", provider1)
54+
assert registry.get_provider("domain") is provider1
55+
56+
registry.set_provider("domain", provider2)
57+
assert registry.get_provider("domain") is provider2
58+
59+
provider1.shutdown.assert_called_once()
60+
provider2.shutdown.assert_not_called()
61+
62+
63+
def test_registering_provider_for_first_time_initializes_it():
64+
"""Test that registering a provider for the first time calls its initialize method."""
65+
66+
registry = ProviderRegistry()
67+
provider = Mock()
68+
69+
registry.set_provider("domain1", provider)
70+
registry.set_provider("domain2", provider)
71+
72+
provider.initialize.assert_called_once()
73+
74+
75+
def test_setting_default_provider_requires_provider():
76+
registry = ProviderRegistry()
77+
78+
with pytest.raises(GeneralError) as exc_info:
79+
registry.set_default_provider(None) # type: ignore[reportArgumentType]
80+
81+
assert exc_info.value.error_message == "No provider"
82+
83+
84+
def test_replacing_default_provider_shuts_down_old_one():
85+
"""Test that replacing the default provider shuts down the old default provider."""
86+
87+
registry = ProviderRegistry()
88+
default_provider1 = Mock()
89+
default_provider2 = Mock()
90+
91+
registry.set_default_provider(default_provider1)
92+
assert registry.get_default_provider() is default_provider1
93+
94+
registry.set_default_provider(default_provider2)
95+
assert registry.get_default_provider() is default_provider2
96+
97+
default_provider1.shutdown.assert_called_once()
98+
default_provider2.shutdown.assert_not_called()
99+
100+
101+
def test_setting_default_provider_initializes_it():
102+
registry = ProviderRegistry()
103+
provider = Mock()
104+
105+
registry.set_default_provider(provider)
106+
107+
provider.initialize.assert_called_once()

0 commit comments

Comments
 (0)