Skip to content

Commit a5e04a7

Browse files
authored
Merge pull request #632 from Nayana-R-Gowda/425-SECURITYFEATURE
make test smoketest doctest passed
2 parents c0c12f4 + da144b7 commit a5e04a7

File tree

5 files changed

+80
-11
lines changed

5 files changed

+80
-11
lines changed

.env.example

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,9 @@ JWT_ALGORITHM=HS256
9797
# Expiry time for generated JWT tokens (in minutes; e.g. 7 days)
9898
TOKEN_EXPIRY=10080
9999

100+
# Require all JWT tokens to have expiration claims (true or false)
101+
REQUIRE_TOKEN_EXPIRATION=false
102+
100103
# Used to derive an AES encryption key for secure auth storage
101104
# Must be a non-empty string (e.g. passphrase or random secret)
102105
AUTH_ENCRYPTION_SECRET=my-test-salt

mcpgateway/config.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@
6060
import jq
6161
from jsonpath_ng.ext import parse
6262
from jsonpath_ng.jsonpath import JSONPath
63-
from pydantic import field_validator
63+
from pydantic import Field, field_validator
6464
from pydantic_settings import BaseSettings, NoDecode, SettingsConfigDict
6565

6666
logging.basicConfig(
@@ -116,6 +116,8 @@ class Settings(BaseSettings):
116116
auth_required: bool = True
117117
token_expiry: int = 10080 # minutes
118118

119+
require_token_expiration: bool = Field(default=False, description="Require all JWT tokens to have expiration claims") # Default to flexible mode for backward compatibility
120+
119121
# Encryption key phrase for auth storage
120122
auth_encryption_secret: str = "my-test-salt"
121123

mcpgateway/utils/create_jwt_token.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,14 @@ def _create_jwt_token(
105105
if expires_in_minutes > 0:
106106
expire = _dt.datetime.now(_dt.timezone.utc) + _dt.timedelta(minutes=expires_in_minutes)
107107
payload["exp"] = int(expire.timestamp())
108+
else:
109+
# Warn about non-expiring token
110+
print(
111+
"⚠️ WARNING: Creating token without expiration. This is a security risk!\n"
112+
" Consider using --exp with a value > 0 for production use.\n"
113+
" Once JWT API (#425) is available, use it for automatic token renewal.",
114+
file=sys.stderr,
115+
)
108116
return jwt.encode(payload, secret, algorithm=algorithm)
109117

110118

mcpgateway/utils/verify_credentials.py

Lines changed: 56 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
... basic_auth_user = 'user'
1818
... basic_auth_password = 'pass'
1919
... auth_required = True
20+
... require_token_expiration = False
2021
>>> vc.settings = DummySettings()
2122
>>> import jwt
2223
>>> token = jwt.encode({'sub': 'alice'}, 'secret', algorithm='HS256')
@@ -39,6 +40,7 @@
3940
"""
4041

4142
# Standard
43+
import logging
4244
from typing import Optional
4345

4446
# Third-Party
@@ -51,14 +53,16 @@
5153
)
5254
from fastapi.security.utils import get_authorization_scheme_param
5355
import jwt
54-
from jwt import PyJWTError
5556

5657
# First-Party
5758
from mcpgateway.config import settings
5859

5960
basic_security = HTTPBasic(auto_error=False)
6061
security = HTTPBearer(auto_error=False)
6162

63+
# Standard
64+
logger = logging.getLogger(__name__)
65+
6266

6367
async def verify_jwt_token(token: str) -> dict:
6468
"""Verify and decode a JWT token.
@@ -74,6 +78,7 @@ async def verify_jwt_token(token: str) -> dict:
7478
7579
Raises:
7680
HTTPException: 401 status if the token has expired or is invalid.
81+
MissingRequiredClaimError: If the 'exp' claim is required but missing.
7782
7883
Examples:
7984
>>> from mcpgateway.utils import verify_credentials as vc
@@ -83,6 +88,7 @@ async def verify_jwt_token(token: str) -> dict:
8388
... basic_auth_user = 'user'
8489
... basic_auth_password = 'pass'
8590
... auth_required = True
91+
... require_token_expiration = False
8692
>>> vc.settings = DummySettings()
8793
>>> import jwt
8894
>>> token = jwt.encode({'sub': 'alice'}, 'secret', algorithm='HS256')
@@ -108,22 +114,60 @@ async def verify_jwt_token(token: str) -> dict:
108114
... print(e.status_code, e.detail)
109115
401 Invalid token
110116
"""
117+
# try:
118+
# Decode and validate token
119+
# payload = jwt.decode(
120+
# token,
121+
# settings.jwt_secret_key,
122+
# algorithms=[settings.jwt_algorithm],
123+
# # options={"require": ["exp"]}, # Require expiration
124+
# )
125+
# return payload # Contains the claims (e.g., user info)
126+
# except jwt.ExpiredSignatureError:
127+
# raise HTTPException(
128+
# status_code=status.HTTP_401_UNAUTHORIZED,
129+
# detail="Token has expired",
130+
# headers={"WWW-Authenticate": "Bearer"},
131+
# )
132+
# except PyJWTError:
133+
# raise HTTPException(
134+
# status_code=status.HTTP_401_UNAUTHORIZED,
135+
# detail="Invalid token",
136+
# headers={"WWW-Authenticate": "Bearer"},
137+
# )
111138
try:
112-
# Decode and validate token
113-
payload = jwt.decode(
114-
token,
115-
settings.jwt_secret_key,
116-
algorithms=[settings.jwt_algorithm],
117-
# options={"require": ["exp"]}, # Require expiration
139+
# First decode to check claims
140+
unverified = jwt.decode(token, options={"verify_signature": False})
141+
142+
# Check for expiration claim
143+
if "exp" not in unverified and settings.require_token_expiration:
144+
raise jwt.MissingRequiredClaimError("exp")
145+
146+
# Log warning for non-expiring tokens
147+
if "exp" not in unverified:
148+
logger.warning("JWT token without expiration accepted. " "Consider enabling REQUIRE_TOKEN_EXPIRATION for better security. " f"Token sub: {unverified.get('sub', 'unknown')}")
149+
150+
# Full validation
151+
options = {}
152+
if settings.require_token_expiration:
153+
options["require"] = ["exp"]
154+
155+
payload = jwt.decode(token, settings.jwt_secret_key, algorithms=[settings.jwt_algorithm], options=options)
156+
return payload
157+
158+
except jwt.MissingRequiredClaimError:
159+
raise HTTPException(
160+
status_code=status.HTTP_401_UNAUTHORIZED,
161+
detail="Token is missing required expiration claim. Set REQUIRE_TOKEN_EXPIRATION=false to allow.",
162+
headers={"WWW-Authenticate": "Bearer"},
118163
)
119-
return payload # Contains the claims (e.g., user info)
120164
except jwt.ExpiredSignatureError:
121165
raise HTTPException(
122166
status_code=status.HTTP_401_UNAUTHORIZED,
123167
detail="Token has expired",
124168
headers={"WWW-Authenticate": "Bearer"},
125169
)
126-
except PyJWTError:
170+
except jwt.PyJWTError:
127171
raise HTTPException(
128172
status_code=status.HTTP_401_UNAUTHORIZED,
129173
detail="Invalid token",
@@ -154,6 +198,7 @@ async def verify_credentials(token: str) -> dict:
154198
... basic_auth_user = 'user'
155199
... basic_auth_password = 'pass'
156200
... auth_required = True
201+
... require_token_expiration = False
157202
>>> vc.settings = DummySettings()
158203
>>> import jwt
159204
>>> token = jwt.encode({'sub': 'alice'}, 'secret', algorithm='HS256')
@@ -194,6 +239,7 @@ async def require_auth(credentials: Optional[HTTPAuthorizationCredentials] = Dep
194239
... basic_auth_user = 'user'
195240
... basic_auth_password = 'pass'
196241
... auth_required = True
242+
... require_token_expiration = False
197243
>>> vc.settings = DummySettings()
198244
>>> import jwt
199245
>>> from fastapi.security import HTTPAuthorizationCredentials
@@ -373,6 +419,7 @@ async def require_auth_override(
373419
... basic_auth_user = 'user'
374420
... basic_auth_password = 'pass'
375421
... auth_required = True
422+
... require_token_expiration = False
376423
>>> vc.settings = DummySettings()
377424
>>> import jwt
378425
>>> import asyncio

tests/unit/mcpgateway/utils/test_verify_credentials.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,15 @@
4141
ALGO = "HS256"
4242

4343

44-
def _token(payload: dict, *, exp_delta: int | None = None, secret: str = SECRET) -> str:
44+
# def _token(payload: dict, *, exp_delta: int | None = None, secret: str = SECRET) -> str:
45+
# """Return a signed JWT with optional expiry offset (minutes)."""
46+
# if exp_delta is not None:
47+
# expire = datetime.now(timezone.utc) + timedelta(minutes=exp_delta)
48+
# payload = payload | {"exp": int(expire.timestamp())}
49+
# return jwt.encode(payload, secret, algorithm=ALGO)
50+
51+
52+
def _token(payload: dict, *, exp_delta: int | None = 60, secret: str = SECRET) -> str:
4553
"""Return a signed JWT with optional expiry offset (minutes)."""
4654
if exp_delta is not None:
4755
expire = datetime.now(timezone.utc) + timedelta(minutes=exp_delta)
@@ -56,6 +64,7 @@ def _token(payload: dict, *, exp_delta: int | None = None, secret: str = SECRET)
5664
async def test_verify_jwt_token_success(monkeypatch):
5765
monkeypatch.setattr(vc.settings, "jwt_secret_key", SECRET, raising=False)
5866
monkeypatch.setattr(vc.settings, "jwt_algorithm", ALGO, raising=False)
67+
monkeypatch.setattr(vc.settings, "require_token_expiration", False, raising=False)
5968

6069
token = _token({"sub": "abc"})
6170
data = await vc.verify_jwt_token(token)

0 commit comments

Comments
 (0)