diff --git a/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/oauth/flow_state.py b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/oauth/flow_state.py index efeb7cb2..369e29ec 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/oauth/flow_state.py +++ b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/oauth/flow_state.py @@ -33,6 +33,7 @@ class FlowErrorTag(Enum): NONE = "none" MAGIC_FORMAT = "magic_format" MAGIC_CODE_INCORRECT = "magic_code_incorrect" + DUPLICATE_EXCHANGE = "duplicate_exchange" OTHER = "other" diff --git a/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/oauth/oauth_flow.py b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/oauth/oauth_flow.py index 3a12b890..dff824f8 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/oauth/oauth_flow.py +++ b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/oauth/oauth_flow.py @@ -4,10 +4,11 @@ from __future__ import annotations import logging +import time from pydantic import BaseModel from datetime import datetime -from typing import Optional +from typing import Optional, Set from microsoft_agents.activity import ( Activity, @@ -104,6 +105,20 @@ def __init__( self._max_attempts, ) + # Public registry of token exchange ids already seen by this flow. It is + # cleared lazily by _maybe_clear_token_exchange_registry to bound its growth. + self.token_exchange_id_registry: Set[str] = set() + + # Seconds between registry clears (default 10). No background task is used; + # the interval is checked lazily at the start of continue_flow. + self._clear_interval_seconds = kwargs.get( + "token_exchange_registry_clear_interval", 10 + ) + # Track the last time the registry was cleared (epoch seconds). The + # registry will be cleared lazily when an async entrypoint is hit and + # the interval has elapsed. 
+ self._last_registry_clear: float = time.time() + @property def flow_state(self) -> FlowState: return self._flow_state.model_copy() @@ -246,16 +261,26 @@ async def _continue_from_invoke_verify_state( async def _continue_from_invoke_token_exchange( self, activity: Activity - ) -> TokenResponse: + ) -> tuple[Optional[TokenResponse], FlowErrorTag]: """Handles the continuation of the flow from an invoke activity for token exchange.""" + logger.info("Continuing OAuth flow with token exchange...") token_exchange_request = activity.value + token_exchange_id = token_exchange_request.get("id") + + if token_exchange_id in self.token_exchange_id_registry: + logger.warning( + "Token exchange request with id %s has already been processed", + token_exchange_id, + ) + return None, FlowErrorTag.DUPLICATE_EXCHANGE + self.token_exchange_id_registry.add(token_exchange_id) token_response = await self._user_token_client.user_token.exchange_token( user_id=self._user_id, connection_name=self._abs_oauth_connection_name, channel_id=self._channel_id, body=token_exchange_request, ) - return token_response + return token_response, FlowErrorTag.NONE async def continue_flow(self, activity: Activity) -> FlowResponse: """Continues the OAuth flow based on the incoming activity. @@ -269,7 +294,24 @@ async def continue_flow(self, activity: Activity) -> FlowResponse: """ logger.debug("Continuing auth flow...") + # Lazily clear the registry if the configured interval has elapsed. 
+ self._maybe_clear_token_exchange_registry() + if not self._flow_state.is_active(): + if ( + activity.type == ActivityTypes.invoke + and activity.name == "signin/tokenExchange" + and activity.value.get("id") in self.token_exchange_id_registry + ): + logger.debug( + "Token exchange request with id %s has already been processed", + activity.value.get("id"), + ) + return FlowResponse( + flow_state=self._flow_state.model_copy(), + token_response=None, + flow_error_tag=FlowErrorTag.DUPLICATE_EXCHANGE, + ) logger.debug("OAuth flow is not active, cannot continue") self._flow_state.tag = FlowStateTag.FAILURE return FlowResponse( @@ -288,14 +330,20 @@ async def continue_flow(self, activity: Activity) -> FlowResponse: activity.type == ActivityTypes.invoke and activity.name == "signin/tokenExchange" ): - token_response = await self._continue_from_invoke_token_exchange(activity) + ( + token_response, + flow_error_tag, + ) = await self._continue_from_invoke_token_exchange(activity) else: raise ValueError(f"Unknown activity type {activity.type}") if not token_response and flow_error_tag == FlowErrorTag.NONE: flow_error_tag = FlowErrorTag.OTHER - if flow_error_tag != FlowErrorTag.NONE: + if ( + flow_error_tag != FlowErrorTag.NONE + and flow_error_tag != FlowErrorTag.DUPLICATE_EXCHANGE + ): logger.debug("Flow error occurred: %s", flow_error_tag) self._flow_state.tag = FlowStateTag.CONTINUE self._use_attempt() @@ -340,3 +388,20 @@ async def begin_or_continue_flow(self, activity: Activity) -> FlowResponse: logger.debug("No active flow, beginning new flow...") return await self.begin_flow(activity) + + def _maybe_clear_token_exchange_registry(self) -> None: + """Clear the `token_exchange_id_registry` if the configured interval + (seconds) has elapsed since the last clear. This uses the machine + epoch (time.time()) and performs lazy eviction when registry access + occurs instead of running a background task. 
+ """ + now = time.time() + + if now - self._last_registry_clear >= self._clear_interval_seconds: + if self.token_exchange_id_registry: + logger.debug( + "Clearing token_exchange_id_registry by epoch check (size=%d)", + len(self.token_exchange_id_registry), + ) + self.token_exchange_id_registry.clear() + self._last_registry_clear = now diff --git a/tests/_common/storage/utils.py b/tests/_common/storage/utils.py index 9a8f95ae..5a8de601 100644 --- a/tests/_common/storage/utils.py +++ b/tests/_common/storage/utils.py @@ -4,13 +4,10 @@ from abc import ABC from typing import Any -from microsoft_agents.hosting.core.storage import ( - Storage, - StoreItem, - MemoryStorage -) +from microsoft_agents.hosting.core.storage import Storage, StoreItem, MemoryStorage from microsoft_agents.hosting.core.storage._type_aliases import JSON + class MockStoreItem(StoreItem): """Test implementation of StoreItem for testing purposes""" diff --git a/tests/hosting_core/test_oauth_flow.py b/tests/hosting_core/test_oauth_flow.py index ea6268b8..32015406 100644 --- a/tests/hosting_core/test_oauth_flow.py +++ b/tests/hosting_core/test_oauth_flow.py @@ -25,7 +25,6 @@ class TestOAuthFlowUtils: - def create_user_token_client(self, mocker, get_token_return=None): user_token_client = mocker.Mock(spec=UserTokenClientBase) @@ -104,7 +103,6 @@ def flow(self, sample_flow_state, user_token_client): class TestOAuthFlow(TestOAuthFlowUtils): - def test_init_no_user_token_client(self, sample_flow_state): with pytest.raises(ValueError): OAuthFlow(sample_flow_state, None) @@ -602,3 +600,82 @@ async def test_begin_or_continue_flow_completed_flow_state(self, mocker): assert actual_response == expected_response OAuthFlow.begin_flow.assert_not_called() OAuthFlow.continue_flow.assert_not_called() + + @pytest.mark.asyncio + async def test_token_exchange_dedupe_prevents_replay( + self, mocker, sample_active_flow_state, user_token_client + ): + # setup + token_exchange_request = {"id": "exchange-1"} + 
+ user_token_client.user_token.exchange_token = mocker.AsyncMock( + return_value=TokenResponse(token=RES_TOKEN) + ) + activity = self.create_activity( + mocker, + ActivityTypes.invoke, + name="signin/tokenExchange", + value=token_exchange_request, + ) + + flow = OAuthFlow(sample_active_flow_state, user_token_client) + + # first request should be processed + response1 = await flow.continue_flow(activity) + user_token_client.user_token.exchange_token.assert_called_once_with( + user_id=sample_active_flow_state.user_id, + connection_name=sample_active_flow_state.connection, + channel_id=sample_active_flow_state.channel_id, + body=token_exchange_request, + ) + assert response1.token_response == TokenResponse(token=RES_TOKEN) + # registry should contain the processed id + assert "exchange-1" in flow.token_exchange_id_registry + + # second request with same id should be ignored (no additional call) + response2 = await flow.continue_flow(activity) + # still only called once + assert user_token_client.user_token.exchange_token.call_count == 1 + assert response2.token_response is None + assert response2.flow_error_tag == FlowErrorTag.DUPLICATE_EXCHANGE + + @pytest.mark.asyncio + async def test_token_exchange_registry_clears_after_interval( + self, mocker, sample_active_flow_state, user_token_client + ): + # setup + token_exchange_request = {"id": "exchange-2"} + user_token_client.user_token.exchange_token = mocker.AsyncMock( + return_value=TokenResponse(token=RES_TOKEN) + ) + activity = self.create_activity( + mocker, + ActivityTypes.invoke, + name="signin/tokenExchange", + value=token_exchange_request, + ) + + flow = OAuthFlow(sample_active_flow_state, user_token_client) + + # first request should be processed + response1 = await flow.continue_flow(activity) + assert user_token_client.user_token.exchange_token.call_count == 1 + assert response1.token_response == TokenResponse(token=RES_TOKEN) + # registry should contain the processed id + assert "exchange-2" in 
flow.token_exchange_id_registry + + # simulate passage of time beyond the clear interval so the registry is cleared + import time as _time + + flow._last_registry_clear = _time.time() - (flow._clear_interval_seconds + 100) + + # explicitly invoke the lazy clear helper to simulate the moment when + # the registry would be cleared and assert it was removed. + flow._maybe_clear_token_exchange_registry() + flow._flow_state.tag = FlowStateTag.CONTINUE # keep it active + assert "exchange-2" not in flow.token_exchange_id_registry + + # second request should now be processed again (registry was lazily cleared) + + response2 = await flow.continue_flow(activity) + assert user_token_client.user_token.exchange_token.call_count == 2 + assert response2.token_response == TokenResponse(token=RES_TOKEN)