Skip to content
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import grpc
from google.protobuf.json_format import MessageToDict
from google.protobuf.struct_pb2 import Struct
from grpc import StatusCode

from openfeature.evaluation_context import EvaluationContext
from openfeature.event import ProviderEventDetails
Expand Down Expand Up @@ -210,6 +210,24 @@ def _create_request_args(self) -> dict:

return request_args

def _fetch_metadata(self) -> typing.Optional[sync_pb2.GetMetadataResponse]:
if self.config.sync_metadata_disabled:
return None

context_values_request = sync_pb2.GetMetadataRequest()
context_values_response: sync_pb2.GetMetadataResponse
try:
context_values_response = self.stub.GetMetadata(
context_values_request, wait_for_ready=True
)
return context_values_response
except grpc.RpcError as e:
if e.code() == StatusCode.UNIMPLEMENTED:
logger.debug(f"Error getting sync metadata: {e}")
return None
else:
raise e

def listen(self) -> None:
call_args = (
{"timeout": self.streamline_deadline_seconds}
Expand All @@ -220,18 +238,7 @@ def listen(self) -> None:

while self.active:
try:
context_values_response: sync_pb2.GetMetadataResponse
if self.config.sync_metadata_disabled:
context_values_response = sync_pb2.GetMetadataResponse(
metadata=Struct()
)
else:
context_values_request = sync_pb2.GetMetadataRequest()
context_values_response = self.stub.GetMetadata(
context_values_request, wait_for_ready=True
)

context_values = MessageToDict(context_values_response)
context_values_response = self._fetch_metadata()

request = sync_pb2.SyncFlagsRequest(**request_args)

Expand All @@ -245,12 +252,20 @@ def listen(self) -> None:
)
self.flag_store.update(json.loads(flag_str))

context_values = {}
if flag_rsp.sync_context:
context_values = MessageToDict(flag_rsp.sync_context)
elif context_values_response:
context_values = MessageToDict(context_values_response)[
"metadata"
]

if not self.connected:
self.emit_provider_ready(
ProviderEventDetails(
message="gRPC sync connection established"
),
context_values["metadata"],
context_values,
)
self.connected = True

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class TestProviderType(Enum):
SSL = "ssl"
SOCKET = "socket"
METADATA = "metadata"
SYNCPAYLOAD = "syncpayload"


@given("a provider is registered", target_fixture="client")
Expand Down Expand Up @@ -71,6 +72,8 @@ def get_default_options_for_provider(
return options, True
elif t == TestProviderType.METADATA:
launchpad = "metadata"
elif t == TestProviderType.SYNCPAYLOAD:
launchpad = "sync-payload"

if resolver_type == ResolverType.FILE:
if "selector" in option_values:
Expand Down
126 changes: 126 additions & 0 deletions providers/openfeature-provider-flagd/tests/test_grpc_watcher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
import threading
import time
import unittest
from unittest.mock import MagicMock, Mock, patch

from google.protobuf.json_format import MessageToDict
from google.protobuf.struct_pb2 import Struct
from grpc import Channel

from openfeature.contrib.provider.flagd.config import Config
from openfeature.contrib.provider.flagd.resolvers.process.connector.grpc_watcher import (
GrpcWatcher,
)
from openfeature.contrib.provider.flagd.resolvers.process.flags import FlagStore
from openfeature.event import ProviderEventDetails
from openfeature.schemas.protobuf.flagd.sync.v1.sync_pb2 import (
GetMetadataResponse,
SyncFlagsResponse,
)
from openfeature.schemas.protobuf.flagd.sync.v1.sync_pb2_grpc import FlagSyncServiceStub


class TestGrpcWatcher(unittest.TestCase):
def setUp(self):
config = Mock(spec=Config)
config.retry_backoff_ms = 1000
config.retry_backoff_max_ms = 5000
config.retry_grace_period = 5
config.stream_deadline_ms = 1000
config.deadline_ms = 5000
config.selector = None
config.provider_id = None
config.tls = False
config.cert_path = None
config.channel_credentials = None
config.host = "localhost"
config.port = 5000
config.sync_metadata_disabled = False

flag_store = Mock(spec=FlagStore)
flag_store.update.return_value = None
self.emit_provider_ready = Mock()
emit_provider_error = Mock()
emit_provider_stale = Mock()
channel = Mock(spec=Channel)

with patch(
"openfeature.contrib.provider.flagd.resolvers.process.connector.grpc_watcher.GrpcWatcher._generate_channel",
return_value=channel,
):
self.grpc_watcher = GrpcWatcher(
config=config,
flag_store=flag_store,
emit_provider_ready=self.emit_provider_ready,
emit_provider_error=emit_provider_error,
emit_provider_stale=emit_provider_stale,
)
self.mock_stub = MagicMock(spec=FlagSyncServiceStub)
self.mock_metadata = GetMetadataResponse(metadata={"attribute": "value1"})
self.mock_stub.GetMetadata = Mock(return_value=self.mock_metadata)
self.grpc_watcher.stub = self.mock_stub
self.grpc_watcher.active = True
self.shutdown_thread = lambda: threading.Thread(
target=self.shutdown_after_x_seconds
)

def shutdown_after_x_seconds(self, seconds=1):
time.sleep(seconds)
self.grpc_watcher.shutdown()

def test_listen_with_sync_metadata_and_sync_context(self):
sync_context = Struct()
sync_context.update({"attribute": "value"})
mock_stream_with_sync_context = iter(
[
SyncFlagsResponse(
flag_configuration='{"flag_key": "flag_value"}',
sync_context=sync_context,
),
]
)
self.mock_stub.SyncFlags = Mock(return_value=mock_stream_with_sync_context)

self.shutdown_thread().start()

self.grpc_watcher.listen()

self.emit_provider_ready.assert_called_once_with(
ProviderEventDetails(message="gRPC sync connection established"),
MessageToDict(sync_context),
)

def test_listen_with_sync_metadata_only(self):
mock_stream_no_sync_context = iter(
[
SyncFlagsResponse(flag_configuration='{"flag_key": "flag_value"}'),
]
)
self.mock_stub.SyncFlags = Mock(return_value=mock_stream_no_sync_context)

self.shutdown_thread().start()

self.grpc_watcher.listen()

self.emit_provider_ready.assert_called_once_with(
ProviderEventDetails(message="gRPC sync connection established"),
MessageToDict(self.mock_metadata.metadata),
)

def test_listen_with_sync_metadata_disabled_in_config(self):
self.grpc_watcher.config.sync_metadata_disabled = True
mock_stream_no_sync_context = iter(
[
SyncFlagsResponse(flag_configuration='{"flag_key": "flag_value"}'),
]
)
self.mock_stub.SyncFlags = Mock(return_value=mock_stream_no_sync_context)
self.shutdown_thread().start()

self.grpc_watcher.listen()

self.mock_stub.GetMetadata.assert_not_called()

self.emit_provider_ready.assert_called_once_with(
ProviderEventDetails(message="gRPC sync connection established"), {}
)