diff --git a/providers/openfeature-provider-flagd/openfeature/schemas b/providers/openfeature-provider-flagd/openfeature/schemas index 76d611fd..2852d777 160000 --- a/providers/openfeature-provider-flagd/openfeature/schemas +++ b/providers/openfeature-provider-flagd/openfeature/schemas @@ -1 +1 @@ -Subproject commit 76d611fd94689d906af316105ac12670d40f7648 +Subproject commit 2852d7772e6b8674681a6ee6b88db10dbe3f6899 diff --git a/providers/openfeature-provider-flagd/openfeature/test-harness b/providers/openfeature-provider-flagd/openfeature/test-harness index 59c3c3cc..fe68e031 160000 --- a/providers/openfeature-provider-flagd/openfeature/test-harness +++ b/providers/openfeature-provider-flagd/openfeature/test-harness @@ -1 +1 @@ -Subproject commit 59c3c3ccfb018db82281684d231067e332c8103d +Subproject commit fe68e0310fd817a8f9bc1e2559f2277fed3aed34 diff --git a/providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/resolvers/process/connector/grpc_watcher.py b/providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/resolvers/process/connector/grpc_watcher.py index 34eb1c1c..8a0184e3 100644 --- a/providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/resolvers/process/connector/grpc_watcher.py +++ b/providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/resolvers/process/connector/grpc_watcher.py @@ -6,7 +6,7 @@ import grpc from google.protobuf.json_format import MessageToDict -from google.protobuf.struct_pb2 import Struct +from grpc import StatusCode from openfeature.evaluation_context import EvaluationContext from openfeature.event import ProviderEventDetails @@ -210,6 +210,23 @@ def _create_request_args(self) -> dict: return request_args + def _fetch_metadata(self) -> typing.Optional[sync_pb2.GetMetadataResponse]: + if self.config.sync_metadata_disabled: + return None + + context_values_request = sync_pb2.GetMetadataRequest() + try: + context_values_response: sync_pb2.GetMetadataResponse = ( + self.stub.GetMetadata(context_values_request, wait_for_ready=True) + ) + return context_values_response + except grpc.RpcError as e: + if e.code() == StatusCode.UNIMPLEMENTED: + logger.debug(f"Error getting sync metadata: {e}") + return None + else: + raise e + def listen(self) -> None: call_args = ( {"timeout": self.streamline_deadline_seconds} @@ -220,18 +237,7 @@ def listen(self) -> None: while self.active: try: - context_values_response: sync_pb2.GetMetadataResponse - if self.config.sync_metadata_disabled: - context_values_response = sync_pb2.GetMetadataResponse( - metadata=Struct() - ) - else: - context_values_request = sync_pb2.GetMetadataRequest() - context_values_response = self.stub.GetMetadata( - context_values_request, wait_for_ready=True - ) - - context_values = MessageToDict(context_values_response) + context_values_response = self._fetch_metadata() request = sync_pb2.SyncFlagsRequest(**request_args) @@ -245,12 +251,20 @@ def listen(self) -> None: ) self.flag_store.update(json.loads(flag_str)) + context_values = {} + if flag_rsp.sync_context: + context_values = MessageToDict(flag_rsp.sync_context) + elif context_values_response: + context_values = MessageToDict(context_values_response)[ + "metadata" + ] + if not self.connected: self.emit_provider_ready( ProviderEventDetails( message="gRPC sync connection established" ), - context_values["metadata"], + context_values, ) self.connected = True diff --git a/providers/openfeature-provider-flagd/tests/e2e/step/provider_steps.py b/providers/openfeature-provider-flagd/tests/e2e/step/provider_steps.py index 3d8d5195..a2589d12 100644 --- a/providers/openfeature-provider-flagd/tests/e2e/step/provider_steps.py +++ b/providers/openfeature-provider-flagd/tests/e2e/step/provider_steps.py @@ -32,6 +32,7 @@ class TestProviderType(Enum): SSL = "ssl" SOCKET = "socket" METADATA = "metadata" + SYNCPAYLOAD = "syncpayload" @given("a provider is registered", target_fixture="client") @@ -71,6 +72,8 @@ def get_default_options_for_provider( return options, True elif t == TestProviderType.METADATA: launchpad = "metadata" + elif t == TestProviderType.SYNCPAYLOAD: + launchpad = "sync-payload" if resolver_type == ResolverType.FILE: if "selector" in option_values: diff --git a/providers/openfeature-provider-flagd/tests/test_errors.py b/providers/openfeature-provider-flagd/tests/test_errors.py index dc5fded2..e64ca376 100644 --- a/providers/openfeature-provider-flagd/tests/test_errors.py +++ b/providers/openfeature-provider-flagd/tests/test_errors.py @@ -138,5 +138,5 @@ def fail(*args, **kwargs): ) elapsed = time.time() - t - assert abs(elapsed - wait * 0.001) < 0.15 + assert abs(elapsed - wait * 0.001) < 0.17 assert init_failed diff --git a/providers/openfeature-provider-flagd/tests/test_grpc_watcher.py b/providers/openfeature-provider-flagd/tests/test_grpc_watcher.py new file mode 100644 index 00000000..ad7a8015 --- /dev/null +++ b/providers/openfeature-provider-flagd/tests/test_grpc_watcher.py @@ -0,0 +1,135 @@ +import threading +import time +import typing +import unittest +from unittest.mock import MagicMock, Mock, patch + +from google.protobuf.json_format import MessageToDict +from google.protobuf.struct_pb2 import Struct +from grpc import Channel + +from openfeature.contrib.provider.flagd.config import Config +from openfeature.contrib.provider.flagd.resolvers.process.connector.grpc_watcher import ( + GrpcWatcher, +) +from openfeature.contrib.provider.flagd.resolvers.process.flags import FlagStore +from openfeature.event import ProviderEventDetails +from openfeature.schemas.protobuf.flagd.sync.v1.sync_pb2 import ( + GetMetadataResponse, + SyncFlagsResponse, +) +from openfeature.schemas.protobuf.flagd.sync.v1.sync_pb2_grpc import FlagSyncServiceStub + + +class TestGrpcWatcher(unittest.TestCase): + def setUp(self): + config = Mock(spec=Config) + config.retry_backoff_ms = 1000 + config.retry_backoff_max_ms = 5000 + config.retry_grace_period = 5 + config.stream_deadline_ms = 1000 + config.deadline_ms = 5000 + config.selector = None + config.provider_id = None + config.tls = False + config.cert_path = None + config.channel_credentials = None + config.host = "localhost" + config.port = 5000 + config.sync_metadata_disabled = False + + flag_store = Mock(spec=FlagStore) + flag_store.update.return_value = None + emit_provider_error = Mock() + emit_provider_stale = Mock() + channel = Mock(spec=Channel) + self.provider_done = False + self.provider_details: typing.Optional[ProviderEventDetails] = None + self.context: typing.Optional[dict] = None + + with patch( + "openfeature.contrib.provider.flagd.resolvers.process.connector.grpc_watcher.GrpcWatcher._generate_channel", + return_value=channel, + ): + self.grpc_watcher = GrpcWatcher( + config=config, + flag_store=flag_store, + emit_provider_ready=self.provider_ready, + emit_provider_error=emit_provider_error, + emit_provider_stale=emit_provider_stale, + ) + self.mock_stub = MagicMock(spec=FlagSyncServiceStub) + self.mock_metadata = GetMetadataResponse(metadata={"attribute": "value1"}) + self.mock_stub.GetMetadata = Mock(return_value=self.mock_metadata) + self.grpc_watcher.stub = self.mock_stub + self.grpc_watcher.active = True + + def provider_ready(self, details: ProviderEventDetails, context: dict): + self.provider_done = True + self.provider_details = details + self.context = context + + def run_listen_and_shutdown_after(self): + listener = threading.Thread(target=self.grpc_watcher.listen) + listener.start() + for _i in range(0, 100): + if self.provider_done: + break + time.sleep(0.001) + + self.assertTrue(self.provider_done) + self.grpc_watcher.shutdown() + listener.join(timeout=0.5) + + def test_listen_with_sync_metadata_and_sync_context(self): + sync_context = Struct() + sync_context.update({"attribute": "value"}) + mock_stream_with_sync_context = iter( + [ + SyncFlagsResponse( + flag_configuration='{"flag_key": "flag_value"}', + sync_context=sync_context, + ), + ] + ) + self.mock_stub.SyncFlags = Mock(return_value=mock_stream_with_sync_context) + + self.run_listen_and_shutdown_after() + + self.assertEqual( + self.provider_details.message, "gRPC sync connection established" + ) + self.assertEqual(self.context, MessageToDict(sync_context)) + + def test_listen_with_sync_metadata_only(self): + mock_stream_no_sync_context = iter( + [ + SyncFlagsResponse(flag_configuration='{"flag_key": "flag_value"}'), + ] + ) + self.mock_stub.SyncFlags = Mock(return_value=mock_stream_no_sync_context) + + self.run_listen_and_shutdown_after() + + self.assertEqual( + self.provider_details.message, "gRPC sync connection established" + ) + self.assertEqual(self.context, MessageToDict(self.mock_metadata.metadata)) + + def test_listen_with_sync_metadata_disabled_in_config(self): + self.grpc_watcher.config.sync_metadata_disabled = True + mock_stream_no_sync_context = iter( + [ + SyncFlagsResponse(flag_configuration='{"flag_key": "flag_value"}'), + ] + ) + self.mock_stub.SyncFlags = Mock(return_value=mock_stream_no_sync_context) + + self.run_listen_and_shutdown_after() + + self.mock_stub.GetMetadata.assert_not_called() + + self.assertEqual( + self.provider_details.message, "gRPC sync connection established" + ) + self.assertEqual(self.context, {})