Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions durabletask-azuremanaged/durabletask/azuremanaged/client.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from azure.core.credentials import TokenCredential
from typing import Optional

from durabletask.azuremanaged.internal.durabletask_grpc_interceptor import \
DTSDefaultClientInterceptorImpl
from azure.core.credentials import TokenCredential

from durabletask.azuremanaged.internal.durabletask_grpc_interceptor import (
DTSDefaultClientInterceptorImpl,
)
from durabletask.client import TaskHubGrpcClient


Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import grpc
from importlib.metadata import version
from typing import Optional

import grpc
from azure.core.credentials import TokenCredential

from durabletask.azuremanaged.internal.access_token_manager import \
AccessTokenManager
from durabletask.azuremanaged.internal.access_token_manager import AccessTokenManager
from durabletask.internal.grpc_interceptor import (
DefaultClientInterceptorImpl, _ClientCallDetails)
DefaultClientInterceptorImpl,
_ClientCallDetails,
)


class DTSDefaultClientInterceptorImpl (DefaultClientInterceptorImpl):
Expand All @@ -18,7 +20,16 @@ class DTSDefaultClientInterceptorImpl (DefaultClientInterceptorImpl):
interceptor to add additional headers to all calls as needed."""

def __init__(self, token_credential: Optional[TokenCredential], taskhub_name: str):
self._metadata = [("taskhub", taskhub_name)]
try:
# Get the version of the azuremanaged package
sdk_version = version('durabletask-azuremanaged')
except Exception:
# Fallback if version cannot be determined
sdk_version = "unknown"
user_agent = f"durabletask-python/{sdk_version}"
self._metadata = [
("taskhub", taskhub_name),
("x-user-agent", user_agent)] # 'user-agent' is a reserved header in grpc, so we use 'x-user-agent' instead
super().__init__(self._metadata)

if token_credential is not None:
Expand Down
8 changes: 5 additions & 3 deletions durabletask-azuremanaged/durabletask/azuremanaged/worker.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from azure.core.credentials import TokenCredential
from typing import Optional

from durabletask.azuremanaged.internal.durabletask_grpc_interceptor import \
DTSDefaultClientInterceptorImpl
from azure.core.credentials import TokenCredential

from durabletask.azuremanaged.internal.durabletask_grpc_interceptor import (
DTSDefaultClientInterceptorImpl,
)
from durabletask.worker import TaskHubGrpcWorker


Expand Down
108 changes: 108 additions & 0 deletions tests/durabletask-azuremanaged/test_durabletask_grpc_interceptor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import threading
import unittest
from concurrent import futures
from importlib.metadata import version

import grpc

from durabletask.azuremanaged.client import DurableTaskSchedulerClient
from durabletask.azuremanaged.internal.durabletask_grpc_interceptor import (
DTSDefaultClientInterceptorImpl,
)
from durabletask.internal import orchestrator_service_pb2 as pb
from durabletask.internal import orchestrator_service_pb2_grpc as stubs


class MockTaskHubSidecarServiceServicer(stubs.TaskHubSidecarServiceServicer):
"""Mock implementation of the TaskHubSidecarService for testing."""

def __init__(self):
self.captured_metadata = {}
self.requests_received = 0

def GetInstance(self, request, context):
"""Implementation of GetInstance that captures the metadata."""
# Store all metadata key-value pairs from the context
for key, value in context.invocation_metadata():
self.captured_metadata[key] = value

self.requests_received += 1

# Return a mock response
response = pb.GetInstanceResponse(exists=False)
return response


class TestDurableTaskGrpcInterceptor(unittest.TestCase):
"""Tests for the DTSDefaultClientInterceptorImpl class."""

@classmethod
def setUpClass(cls):
# Start a real gRPC server on a free port
cls.server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
cls.port = cls.server.add_insecure_port('[::]:0') # Bind to a random free port
cls.server_address = f"localhost:{cls.port}"

# Add our mock service implementation to the server
cls.mock_servicer = MockTaskHubSidecarServiceServicer()
stubs.add_TaskHubSidecarServiceServicer_to_server(cls.mock_servicer, cls.server)

# Start the server in a background thread
cls.server.start()

@classmethod
def tearDownClass(cls):
cls.server.stop(grace=None)

def test_user_agent_metadata_passed_in_request(self):
"""Test that the user agent metadata is correctly passed in gRPC requests."""
# Create a client that connects to our mock server
# Note: secure_channel is False and token_credential is None as specified
task_hub_client = DurableTaskSchedulerClient(
host_address=self.server_address,
secure_channel=False,
taskhub="test-taskhub",
token_credential=None
)

# Make a client call that will trigger our interceptor
task_hub_client.get_orchestration_state("test-instance-id")

# Verify the request was received by our mock server
self.assertEqual(1, self.mock_servicer.requests_received, "Expected one request to be received")

# Check if our custom x-user-agent header was correctly set
self.assertIn("x-user-agent", self.mock_servicer.captured_metadata, "x-user-agent header not found")

# Get what we expect our user agent to be
try:
expected_version = version('durabletask-azuremanaged')
except Exception:
expected_version = "unknown"

expected_user_agent = f"durabletask-python/{expected_version}"
self.assertEqual(
expected_user_agent,
self.mock_servicer.captured_metadata["x-user-agent"],
f"Expected x-user-agent header to be '{expected_user_agent}'"
)

# Check if the taskhub header was correctly set
self.assertIn("taskhub", self.mock_servicer.captured_metadata, "taskhub header not found")
self.assertEqual("test-taskhub", self.mock_servicer.captured_metadata["taskhub"])

# Verify the standard gRPC user-agent is different from our custom one
# Note: gRPC automatically adds its own "user-agent" header
self.assertIn("user-agent", self.mock_servicer.captured_metadata, "gRPC user-agent header not found")
self.assertNotEqual(
self.mock_servicer.captured_metadata["user-agent"],
self.mock_servicer.captured_metadata["x-user-agent"],
"gRPC user-agent should be different from our custom x-user-agent"
)


if __name__ == "__main__":
unittest.main()
Loading