Skip to content

Commit 513f9c4

Browse files
authored
feat: use corpus from cache for typo checking (#17812)
* fix: handle missing keys in cache Don't try to deserialize if nothing came back. Signed-off-by: Mike Fiedler <[email protected]> * feat: use corpus from cache for typo checking Signed-off-by: Mike Fiedler <[email protected]> --------- Signed-off-by: Mike Fiedler <[email protected]>
1 parent 7808bbf commit 513f9c4

File tree

5 files changed

+27
-6
lines changed

5 files changed

+27
-6
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ ignore_missing_imports = true
8282
[tool.pytest.ini_options]
8383
addopts = [
8484
"--disable-socket",
85-
"--allow-hosts=localhost,::1,stripe",
85+
"--allow-hosts=localhost,::1,stripe,redis",
8686
"--durations=20",
8787
"--numprocesses=auto",
8888
# Disable ddtrace for tests

tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,7 @@ def get_app_config(database, nondefaults=None):
320320
"database.url": database,
321321
"docs.url": "http://docs.example.com/",
322322
"ratelimit.url": "memory://",
323-
"db_results_cache.url": "redis://localhost:0/",
323+
"db_results_cache.url": "redis://redis:0/",
324324
"opensearch.url": "https://localhost/warehouse",
325325
"files.backend": "warehouse.packaging.services.LocalFileStorage",
326326
"archive_files.backend": "warehouse.packaging.services.LocalArchiveFileStorage",

tests/unit/cache/test_services.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,12 @@ def test_create_service(self):
3434

3535
assert isinstance(service, RedisQueryResults)
3636

37+
def test_get_missing(self, query_results_cache_service):
38+
# Attempt to get a value that doesn't exist in the cache
39+
result = query_results_cache_service.get("missing_key")
40+
41+
assert result is None
42+
3743
def test_set_get_simple(self, query_results_cache_service):
3844
# Set a value in the cache
3945
query_results_cache_service.set("test_key", {"foo": "bar"})

warehouse/cache/services.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def get(self, key: str) -> list | dict | None:
4949
"""Get a cached result by key."""
5050
result = self.redis_client.get(key)
5151
# deserialize the value as a JSON object
52-
return orjson.loads(result)
52+
return orjson.loads(result) if result else None
5353

5454
def set(self, key: str, value) -> None:
5555
"""Set a cached result by key."""

warehouse/packaging/services.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from zope.interface import implementer
3535

3636
from warehouse.admin.flags import AdminFlagValue
37+
from warehouse.cache import IQueryResultsCache
3738
from warehouse.email import send_pending_trusted_publisher_invalidated_email
3839
from warehouse.events.tags import EventTag
3940
from warehouse.helpdesk.interfaces import IAdminNotificationService
@@ -411,13 +412,18 @@ def create_service(cls, context, request):
411412

412413
@implementer(IProjectService)
413414
class ProjectService:
414-
def __init__(self, session, metrics=None, ratelimiters=None) -> None:
415+
def __init__(
416+
self, session, metrics=None, ratelimiters=None, query_results_cache=None
417+
) -> None:
415418
if ratelimiters is None:
416419
ratelimiters = {}
420+
if query_results_cache is None:
421+
query_results_cache = {}
417422

418423
self.db = session
419424
self.ratelimiters = collections.defaultdict(DummyRateLimiter, ratelimiters)
420425
self._metrics = metrics
426+
self._query_results_cache = query_results_cache
421427

422428
def _check_ratelimits(self, request, creator):
423429
# First we want to check if a single IP is exceeding our rate limiter.
@@ -486,7 +492,10 @@ def check_project_name(self, name: str) -> None:
486492
raise ProjectNameUnavailableSimilarError(similar_project_name)
487493

488494
# Check for typo-squatting.
489-
if typo_check_match := typo_check_name(canonicalize_name(name)):
495+
cached_corpus = self._query_results_cache.get("top_dependents_corpus")
496+
if typo_check_match := typo_check_name(
497+
canonicalize_name(name), corpus=cached_corpus
498+
):
490499
raise ProjectNameUnavailableTypoSquattingError(
491500
check_name=typo_check_match[0],
492501
existing_project_name=typo_check_match[1],
@@ -718,4 +727,10 @@ def project_service_factory(context, request):
718727
IRateLimiter, name="project.create.ip", context=None
719728
),
720729
}
721-
return ProjectService(request.db, metrics=metrics, ratelimiters=ratelimiters)
730+
query_results_cache = request.find_service(IQueryResultsCache)
731+
return ProjectService(
732+
request.db,
733+
metrics=metrics,
734+
ratelimiters=ratelimiters,
735+
query_results_cache=query_results_cache,
736+
)

0 commit comments

Comments
 (0)