Skip to content

Commit 02ea00b

Browse files
madhav165monshri
authored andcommitted
Fixes OAuth after addition of signature to state (IBM#1097)
* copied from main Signed-off-by: Madhav Kandukuri <[email protected]> * testing changes Signed-off-by: Madhav Kandukuri <[email protected]> * Fix oauth code Signed-off-by: Madhav Kandukuri <[email protected]> * Fix tests in test_oauth_router Signed-off-by: Madhav Kandukuri <[email protected]> * Linting fixes Signed-off-by: Madhav Kandukuri <[email protected]> * remove debug_team_dropdown.md Signed-off-by: Madhav Kandukuri <[email protected]> * String issue fixed Signed-off-by: Madhav Kandukuri <[email protected]> --------- Signed-off-by: Madhav Kandukuri <[email protected]>
1 parent b8c1444 commit 02ea00b

File tree

4 files changed

+85
-20
lines changed

4 files changed

+85
-20
lines changed

mcpgateway/auth.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ async def get_current_user(credentials: Optional[HTTPAuthorizationCredentials] =
6767
logger = logging.getLogger(__name__)
6868

6969
if not credentials:
70-
logger.debug("No credentials provided")
70+
logger.warning("No credentials provided")
7171
raise HTTPException(
7272
status_code=status.HTTP_401_UNAUTHORIZED,
7373
detail="Authentication required",

mcpgateway/routers/oauth_router.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@
2323
from sqlalchemy.orm import Session
2424

2525
# First-Party
26-
from mcpgateway.auth import get_current_user
2726
from mcpgateway.db import Gateway, get_db
27+
from mcpgateway.middleware.rbac import get_current_user_with_permissions
2828
from mcpgateway.schemas import EmailUserResponse
2929
from mcpgateway.services.oauth_manager import OAuthError, OAuthManager
3030
from mcpgateway.services.token_storage_service import TokenStorageService
@@ -35,7 +35,7 @@
3535

3636

3737
@oauth_router.get("/authorize/{gateway_id}")
38-
async def initiate_oauth_flow(gateway_id: str, request: Request, current_user: EmailUserResponse = Depends(get_current_user), db: Session = Depends(get_db)) -> RedirectResponse:
38+
async def initiate_oauth_flow(gateway_id: str, request: Request, current_user: EmailUserResponse = Depends(get_current_user_with_permissions), db: Session = Depends(get_db)) -> RedirectResponse:
3939
"""Initiates the OAuth 2.0 Authorization Code flow for a specified gateway.
4040
4141
This endpoint retrieves the OAuth configuration for the given gateway, validates that
@@ -75,9 +75,9 @@ async def initiate_oauth_flow(gateway_id: str, request: Request, current_user: E
7575

7676
# Initiate OAuth flow with user context
7777
oauth_manager = OAuthManager(token_storage=TokenStorageService(db))
78-
auth_data = await oauth_manager.initiate_authorization_code_flow(gateway_id, gateway.oauth_config, app_user_email=current_user.email)
78+
auth_data = await oauth_manager.initiate_authorization_code_flow(gateway_id, gateway.oauth_config, app_user_email=current_user.get("email"))
7979

80-
logger.info(f"Initiated OAuth flow for gateway {gateway_id} by user {current_user.email}")
80+
logger.info(f"Initiated OAuth flow for gateway {gateway_id} by user {current_user.get('email')}")
8181

8282
# Redirect user to OAuth provider
8383
return RedirectResponse(url=auth_data["authorization_url"])
@@ -132,8 +132,22 @@ async def oauth_callback(
132132
import json
133133

134134
try:
135-
state_decoded = base64.urlsafe_b64decode(state.encode()).decode()
136-
state_data = json.loads(state_decoded)
135+
# Expect state as base64url(payload || signature) where the last 32 bytes
136+
# are the signature. Decode to bytes first so we can split payload vs sig.
137+
state_raw = base64.urlsafe_b64decode(state.encode())
138+
if len(state_raw) <= 32:
139+
raise ValueError("State too short to contain payload and signature")
140+
141+
# Split payload and signature. Signature is the last 32 bytes.
142+
payload_bytes = state_raw[:-32]
143+
# signature_bytes = state_raw[-32:]
144+
145+
# Parse the JSON payload only (not including signature bytes)
146+
try:
147+
state_data = json.loads(payload_bytes.decode())
148+
except Exception as decode_exc:
149+
raise ValueError(f"Failed to parse state payload JSON: {decode_exc}")
150+
137151
gateway_id = state_data.get("gateway_id")
138152
if not gateway_id:
139153
raise ValueError("No gateway_id in state")
@@ -403,7 +417,7 @@ async def get_oauth_status(gateway_id: str, db: Session = Depends(get_db)) -> di
403417

404418

405419
@oauth_router.post("/fetch-tools/{gateway_id}")
406-
async def fetch_tools_after_oauth(gateway_id: str, current_user: EmailUserResponse = Depends(get_current_user), db: Session = Depends(get_db)) -> Dict[str, Any]:
420+
async def fetch_tools_after_oauth(gateway_id: str, current_user: EmailUserResponse = Depends(get_current_user_with_permissions), db: Session = Depends(get_db)) -> Dict[str, Any]:
407421
"""Fetch tools from MCP server after OAuth completion for Authorization Code flow.
408422
409423
Args:
@@ -422,7 +436,7 @@ async def fetch_tools_after_oauth(gateway_id: str, current_user: EmailUserRespon
422436
from mcpgateway.services.gateway_service import GatewayService
423437

424438
gateway_service = GatewayService()
425-
result = await gateway_service.fetch_tools_after_oauth(db, gateway_id, current_user.email)
439+
result = await gateway_service.fetch_tools_after_oauth(db, gateway_id, current_user.get("email"))
426440
tools_count = len(result.get("tools", []))
427441

428442
return {"success": True, "message": f"Successfully fetched and created {tools_count} tools"}

mcpgateway/services/oauth_manager.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -604,8 +604,20 @@ async def _validate_authorization_state(self, gateway_id: str, state: str) -> bo
604604

605605
state_data = json.loads(state_json)
606606

607+
# Parse expires_at as timezone-aware datetime. If the stored value
608+
# is naive, assume UTC for compatibility.
609+
try:
610+
expires_at = datetime.fromisoformat(state_data["expires_at"])
611+
except Exception:
612+
# Fallback: try parsing without microseconds/offsets
613+
expires_at = datetime.strptime(state_data["expires_at"], "%Y-%m-%dT%H:%M:%S")
614+
615+
if expires_at.tzinfo is None:
616+
# Assume UTC for naive timestamps
617+
expires_at = expires_at.replace(tzinfo=timezone.utc)
618+
607619
# Check if state has expired
608-
if datetime.fromisoformat(state_data["expires_at"]) < datetime.now(timezone.utc):
620+
if expires_at < datetime.now(timezone.utc):
609621
logger.warning(f"State has expired for gateway {gateway_id}")
610622
return False
611623

@@ -636,7 +648,12 @@ async def _validate_authorization_state(self, gateway_id: str, state: str) -> bo
636648
return False
637649

638650
# Check if state has expired
639-
if oauth_state.expires_at < datetime.now(timezone.utc):
651+
# Ensure oauth_state.expires_at is timezone-aware. If naive, assume UTC.
652+
expires_at = oauth_state.expires_at
653+
if expires_at.tzinfo is None:
654+
expires_at = expires_at.replace(tzinfo=timezone.utc)
655+
656+
if expires_at < datetime.now(timezone.utc):
640657
logger.warning(f"State has expired for gateway {gateway_id}")
641658
db.delete(oauth_state)
642659
db.commit()
@@ -667,8 +684,12 @@ async def _validate_authorization_state(self, gateway_id: str, state: str) -> bo
667684
logger.warning(f"State not found in memory for gateway {gateway_id}")
668685
return False
669686

670-
# Check if state has expired
671-
if datetime.fromisoformat(state_data["expires_at"]) < datetime.now(timezone.utc):
687+
# Parse and normalize expires_at to timezone-aware datetime
688+
expires_at = datetime.fromisoformat(state_data["expires_at"])
689+
if expires_at.tzinfo is None:
690+
expires_at = expires_at.replace(tzinfo=timezone.utc)
691+
692+
if expires_at < datetime.now(timezone.utc):
672693
logger.warning(f"State has expired for gateway {gateway_id}")
673694
del _oauth_states[state_key] # Clean up expired state
674695
return False

tests/unit/mcpgateway/routers/test_oauth_router.py

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ def mock_gateway(self):
6666
def mock_current_user(self):
6767
"""Create mock current user."""
6868
user = Mock(spec=EmailUserResponse)
69+
user.get = Mock(return_value="[email protected]")
6970
user.email = "[email protected]"
7071
user.full_name = "Test User"
7172
user.is_active = True
@@ -106,7 +107,7 @@ async def test_initiate_oauth_flow_success(self, mock_db, mock_request, mock_gat
106107

107108
mock_oauth_manager_class.assert_called_once_with(token_storage=mock_token_storage)
108109
mock_oauth_manager.initiate_authorization_code_flow.assert_called_once_with(
109-
"gateway123", mock_gateway.oauth_config, app_user_email="[email protected]"
110+
"gateway123", mock_gateway.oauth_config, app_user_email=mock_current_user.get("email")
110111
)
111112

112113
@pytest.mark.asyncio
@@ -194,9 +195,11 @@ async def test_oauth_callback_success(self, mock_db, mock_request, mock_gateway)
194195
import base64
195196
import json
196197

197-
# Setup state with new format
198+
# Setup state with new format (payload + 32-byte signature)
198199
state_data = {"gateway_id": "gateway123", "app_user_email": "[email protected]", "nonce": "abc123"}
199-
state = base64.urlsafe_b64encode(json.dumps(state_data).encode()).decode()
200+
payload = json.dumps(state_data).encode()
201+
signature = b'x' * 32 # Mock 32-byte signature
202+
state = base64.urlsafe_b64encode(payload + signature).decode()
200203

201204
mock_db.execute.return_value.scalar_one_or_none.return_value = mock_gateway
202205

@@ -266,6 +269,27 @@ async def test_oauth_callback_invalid_state(self, mock_db, mock_request):
266269
assert result.status_code == 400
267270
assert "Invalid state parameter" in result.body.decode()
268271

272+
@pytest.mark.asyncio
273+
async def test_oauth_callback_state_too_short(self, mock_db, mock_request):
274+
"""Test OAuth callback with state that's too short to contain signature."""
275+
# Standard
276+
import base64
277+
278+
# Setup - create state with less than 32 bytes total
279+
short_payload = b"short"
280+
state = base64.urlsafe_b64encode(short_payload).decode()
281+
282+
# First-Party
283+
from mcpgateway.routers.oauth_router import oauth_callback
284+
285+
# Execute
286+
result = await oauth_callback(code="auth_code_123", state=state, request=mock_request, db=mock_db)
287+
288+
# Assert
289+
assert isinstance(result, HTMLResponse)
290+
assert result.status_code == 400
291+
assert "Invalid state parameter" in result.body.decode()
292+
269293
@pytest.mark.asyncio
270294
async def test_oauth_callback_gateway_not_found(self, mock_db, mock_request):
271295
"""Test OAuth callback when gateway is not found."""
@@ -275,7 +299,9 @@ async def test_oauth_callback_gateway_not_found(self, mock_db, mock_request):
275299

276300
# Setup
277301
state_data = {"gateway_id": "nonexistent", "app_user_email": "[email protected]"}
278-
state = base64.urlsafe_b64encode(json.dumps(state_data).encode()).decode()
302+
payload = json.dumps(state_data).encode()
303+
signature = b'x' * 32 # Mock 32-byte signature
304+
state = base64.urlsafe_b64encode(payload + signature).decode()
279305

280306
mock_db.execute.return_value.scalar_one_or_none.return_value = None
281307

@@ -299,7 +325,9 @@ async def test_oauth_callback_no_oauth_config(self, mock_db, mock_request):
299325

300326
# Setup
301327
state_data = {"gateway_id": "gateway123", "app_user_email": "[email protected]"}
302-
state = base64.urlsafe_b64encode(json.dumps(state_data).encode()).decode()
328+
payload = json.dumps(state_data).encode()
329+
signature = b'x' * 32 # Mock 32-byte signature
330+
state = base64.urlsafe_b64encode(payload + signature).decode()
303331

304332
mock_gateway = Mock(spec=Gateway)
305333
mock_gateway.id = "gateway123"
@@ -326,7 +354,9 @@ async def test_oauth_callback_oauth_error(self, mock_db, mock_request, mock_gate
326354

327355
# Setup
328356
state_data = {"gateway_id": "gateway123", "app_user_email": "[email protected]"}
329-
state = base64.urlsafe_b64encode(json.dumps(state_data).encode()).decode()
357+
payload = json.dumps(state_data).encode()
358+
signature = b'x' * 32 # Mock 32-byte signature
359+
state = base64.urlsafe_b64encode(payload + signature).decode()
330360

331361
mock_db.execute.return_value.scalar_one_or_none.return_value = mock_gateway
332362

@@ -412,7 +442,7 @@ async def test_fetch_tools_after_oauth_success(self, mock_db, mock_current_user)
412442
# Assert
413443
assert result["success"] is True
414444
assert "Successfully fetched and created 3 tools" in result["message"]
415-
mock_gateway_service.fetch_tools_after_oauth.assert_called_once_with(mock_db, "gateway123", "[email protected]")
445+
mock_gateway_service.fetch_tools_after_oauth.assert_called_once_with(mock_db, "gateway123", mock_current_user.get("email"))
416446

417447
@pytest.mark.asyncio
418448
async def test_fetch_tools_after_oauth_no_tools(self, mock_db, mock_current_user):

0 commit comments

Comments
 (0)