Skip to content

Commit 17b80f8

Browse files
committed
fix(svc): create ensure_valid method for regroup entity domain errors, delete empty check for get svc method
1 parent 9b13ede commit 17b80f8

File tree

5 files changed

+43
-44
lines changed

5 files changed

+43
-44
lines changed

src/fastapi_api_key/domain/base.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,3 +103,18 @@ def ensure_valid_scopes(self, required_scopes: List[str]) -> None:
103103
InvalidScopes: If the key does not have the required scopes.
104104
"""
105105
...
106+
107+
def ensure_valid(self, scopes: List[str]) -> None:
108+
"""Ensure the API key is valid for authentication and scopes.
109+
110+
This is a convenience method that combines both `ensure_can_authenticate`
111+
and `ensure_valid_scopes`.
112+
113+
Arguments:
114+
scopes (List[str]): List of required scopes to check against the key's scopes.
115+
Raises:
116+
KeyInactive: If the key is disabled.
117+
KeyExpired: If the key is expired.
118+
InvalidScopes: If the key does not have the required scopes.
119+
"""
120+
...

src/fastapi_api_key/domain/entities.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,10 @@ def ensure_valid_scopes(self, required_scopes: List[str]) -> None:
174174
if missing_scopes:
175175
raise InvalidScopes(f"API key is missing required scopes: {missing_scopes_str}")
176176

177+
def ensure_valid(self, scopes: List[str]) -> None:
178+
self.ensure_can_authenticate()
179+
self.ensure_valid_scopes(scopes)
180+
177181
def __repr__(self):
178182
return (
179183
f"ApiKey(id_={self.id_!r}, name={self.name!r}, description={self.description!r}, "

src/fastapi_api_key/services/base.py

Lines changed: 21 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -295,18 +295,15 @@ async def load_dotenv(self, envvar_prefix: str = "API_KEY_"):
295295
raise ConfigurationError(f"No environment variables found with prefix '{envvar_prefix}'")
296296

297297
for key, api_key in zip(list_keys, list_api_key):
298-
global_prefix, key_id, key_secret = self._get_parts(api_key)
298+
parsed = self._get_parts(api_key)
299299

300300
await self.create(
301301
name=key,
302-
key_id=key_id,
303-
key_secret=key_secret,
302+
key_id=parsed.key_id,
303+
key_secret=parsed.key_secret,
304304
)
305305

306306
async def get_by_id(self, id_: str) -> ApiKey:
307-
if id_.strip() == "":
308-
raise KeyNotProvided("No API key provided")
309-
310307
entity = await self._repo.get_by_id(id_)
311308

312309
if entity is None:
@@ -315,9 +312,6 @@ async def get_by_id(self, id_: str) -> ApiKey:
315312
return entity
316313

317314
async def get_by_key_id(self, key_id: str) -> ApiKey:
318-
if not key_id.strip():
319-
raise KeyNotProvided("No API key key_id provided (key_id cannot be empty)")
320-
321315
entity = await self._repo.get_by_key_id(key_id)
322316

323317
if entity is None:
@@ -426,20 +420,10 @@ def _parse_and_validate_key(self, api_key: Optional[str]) -> ParsedApiKey:
426420
KeyNotProvided: If the key is None or empty.
427421
InvalidKey: If the format or prefix is invalid.
428422
"""
429-
if api_key is None:
423+
if api_key is None or api_key.strip() == "":
430424
raise KeyNotProvided("Api key must be provided (not given)")
431425

432-
global_prefix, key_id, key_secret = self._get_parts(api_key)
433-
434-
if global_prefix != self.global_prefix:
435-
raise InvalidKey("Api key is invalid (wrong global prefix)")
436-
437-
return ParsedApiKey(
438-
global_prefix=global_prefix,
439-
key_id=key_id,
440-
key_secret=key_secret,
441-
raw=api_key,
442-
)
426+
return self._get_parts(api_key)
443427

444428
async def _verify_entity(self, entity: ApiKey, key_secret: str, required_scopes: List[str]) -> ApiKey:
445429
"""Verify that an entity can authenticate with the provided secret.
@@ -458,18 +442,17 @@ async def _verify_entity(self, entity: ApiKey, key_secret: str, required_scopes:
458442
InvalidKey: If the hash does not match.
459443
InvalidScopes: If scopes are insufficient.
460444
"""
461-
assert entity.key_hash is not None, "key_hash must be set for existing API keys" # nosec B101
445+
# Todo: IDK if this line ise usefully
446+
# assert entity.key_hash is not None, "key_hash must be set for existing API keys" # nosec B101
462447

463-
entity.ensure_can_authenticate()
448+
entity.ensure_valid(scopes=required_scopes)
464449

465450
if not self._hasher.verify(entity.key_hash, key_secret):
466451
raise InvalidKey("API key is invalid (hash mismatch)")
467452

468-
entity.ensure_valid_scopes(required_scopes)
469-
470453
return await self.touch(entity)
471454

472-
def _get_parts(self, api_key: str) -> Tuple[str, str, str]:
455+
def _get_parts(self, api_key: str) -> ParsedApiKey:
473456
"""Extract the parts of the API key string.
474457
475458
Args:
@@ -486,10 +469,20 @@ def _get_parts(self, api_key: str) -> Tuple[str, str, str]:
486469
if len(parts) != 3:
487470
raise InvalidKey("API key format is invalid (wrong number of segments).")
488471

489-
if not all(parts):
472+
if not all(p.strip() for p in parts):
490473
raise InvalidKey("API key format is invalid (empty segment).")
491474

492-
return parts[0], parts[1], parts[2]
475+
parsed_api_key = ParsedApiKey(
476+
global_prefix=parts[0],
477+
key_id=parts[1],
478+
key_secret=parts[2],
479+
raw=api_key,
480+
)
481+
482+
if parsed_api_key.global_prefix != self.global_prefix:
483+
raise InvalidKey("Api key is invalid (wrong global prefix)")
484+
485+
return parsed_api_key
493486

494487
async def touch(self, entity: ApiKey) -> ApiKey:
495488
"""Update last_used_at to now and persist the change."""

src/fastapi_api_key/services/cached.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,12 +112,11 @@ async def _verify_key(self, api_key: Optional[str] = None, required_scopes: Opti
112112

113113
# Compute cache key from the full API key (secure: requires complete key)
114114
cache_key = _compute_cache_key(parsed.raw)
115-
cached_entity = await self.cache.get(cache_key)
115+
cached_entity: ApiKey = await self.cache.get(cache_key)
116116

117117
if cached_entity:
118118
# Cache hit: the full API key is correct (hash matched)
119-
cached_entity.ensure_can_authenticate()
120-
cached_entity.ensure_valid_scopes(required_scopes)
119+
cached_entity.ensure_valid(scopes=required_scopes)
121120
return await self.touch(cached_entity)
122121

123122
# Cache miss: perform full verification via parent's helper

tests/unit/test_service.py

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -92,12 +92,6 @@ async def test_get_by_id_success(self, service: ApiKeyService):
9292
result = await service.get_by_id(entity.id_)
9393
assert result.id_ == entity.id_
9494

95-
@pytest.mark.asyncio
96-
async def test_get_by_id_empty_raises(self, service: ApiKeyService):
97-
"""get_by_id() raises KeyNotProvided for empty ID."""
98-
with pytest.raises(KeyNotProvided):
99-
await service.get_by_id(" ")
100-
10195
@pytest.mark.asyncio
10296
async def test_get_by_id_not_found_raises(self, service: ApiKeyService):
10397
"""get_by_id() raises KeyNotFound for missing ID."""
@@ -118,12 +112,6 @@ async def test_get_by_key_id_not_found_raises(self, service: ApiKeyService):
118112
with pytest.raises(KeyNotFound):
119113
await service.get_by_key_id("missing")
120114

121-
@pytest.mark.asyncio
122-
async def test_get_by_key_id_empty_raises(self, service: ApiKeyService):
123-
"""get_by_key_id() raises KeyNotProvided for empty key_id."""
124-
with pytest.raises(KeyNotProvided):
125-
await service.get_by_key_id(" ")
126-
127115

128116
class TestServiceUpdate:
129117
"""Tests for update() method."""
@@ -217,7 +205,7 @@ async def test_verify_none_raises(self, service: ApiKeyService):
217205
@pytest.mark.asyncio
218206
async def test_verify_empty_raises(self, service: ApiKeyService):
219207
"""verify_key() raises InvalidKey for empty/whitespace string."""
220-
with pytest.raises(InvalidKey, match="wrong number of segments"):
208+
with pytest.raises(KeyNotProvided, match=r"Api key must be provided \(not given\)"):
221209
await service.verify_key(" ")
222210

223211
@pytest.mark.asyncio

0 commit comments

Comments
 (0)