Skip to content

Commit 1f7c419

Browse files
authored
perf(api): Improve asynchronous behaviour (google#3638)
Change all(?) the ndb calls used throughout the API handlers to use the async versions, and rewrote a few functions to not block on the completion of futures that it doesn't need to (particularly, converting the Bug to a proto, and populating the related/alias/upstream fields of those protos). Testing on my private instance, this should be a 3-5x speedup for queries with lots of vulnerabilities. I've also made the batch queries properly set the modified time from the alias/upstream group (which might actually be a tiny performance loss) to ensure the modified time is always the same for the batch query and the get by ID. I've added `@ndb.synctasklet` to the rpc handlers to let us use `yield` instead of `.result()`, just to discourage usage of it (using `.result()` in a tasklet can cause stack overflows)
1 parent e536895 commit 1f7c419

File tree

2 files changed

+88
-64
lines changed

2 files changed

+88
-64
lines changed

gcp/api/server.py

Lines changed: 31 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import argparse
1717
import codecs
1818
from dataclasses import dataclass
19-
from datetime import datetime, timedelta, UTC
19+
from datetime import datetime, timedelta
2020
import math
2121
import hashlib
2222
import functools
@@ -43,7 +43,6 @@
4343
from osv import ecosystems
4444
from osv import semver_index
4545
from osv import purl_helpers
46-
from osv import vulnerability_pb2
4746
from osv.logs import setup_gcp_logging
4847
import osv_service_v1_pb2
4948
import osv_service_v1_pb2_grpc
@@ -101,7 +100,7 @@
101100
# ----
102101
# Type Aliases:
103102

104-
ToResponseCallable = Callable[[osv.Bug], vulnerability_pb2.Vulnerability]
103+
ToResponseCallable = Callable[[osv.Bug], ndb.Future]
105104

106105
# ----
107106

@@ -163,14 +162,15 @@ class OSVServicer(osv_service_v1_pb2_grpc.OSVServicer,
163162

164163
@ndb_context
165164
@trace_filter.log_trace
165+
@ndb.synctasklet
166166
def GetVulnById(self, request, context: grpc.ServicerContext):
167167
"""Return a `Vulnerability` object for a given OSV ID."""
168-
bug = osv.Bug.get_by_id(request.id)
168+
bug = yield osv.Bug.query(osv.Bug.db_id == request.id).get_async()
169169

170170
if not bug:
171171
# Check for aliases
172-
alias_group = osv.AliasGroup.query(
173-
osv.AliasGroup.bug_ids == request.id).get()
172+
alias_group = yield osv.AliasGroup.query(
173+
osv.AliasGroup.bug_ids == request.id).get_async()
174174
if alias_group:
175175
alias_string = ' '.join([
176176
f'{alias}' for alias in alias_group.bug_ids if alias != request.id
@@ -190,10 +190,12 @@ def GetVulnById(self, request, context: grpc.ServicerContext):
190190
context.abort(grpc.StatusCode.PERMISSION_DENIED, 'Permission denied.')
191191
return None
192192

193-
return bug_to_response(bug, include_alias=True)
193+
resp = yield bug_to_response(bug, include_details=True)
194+
return resp
194195

195196
@ndb_context
196197
@trace_filter.log_trace
198+
@ndb.synctasklet
197199
def QueryAffected(self, request, context: grpc.ServicerContext):
198200
"""Query vulnerabilities for a particular project at a given commit or
199201
@@ -238,8 +240,7 @@ def QueryAffected(self, request, context: grpc.ServicerContext):
238240
total_responses=ResponsesCount(0))
239241

240242
try:
241-
results, next_page_token = do_query(
242-
request.query, query_context).result() # type: ignore
243+
results, next_page_token = yield do_query(request.query, query_context)
243244
except InvalidArgument:
244245
# Currently cannot think of any other way
245246
# this can be raised other than invalid cursor
@@ -257,6 +258,7 @@ def QueryAffected(self, request, context: grpc.ServicerContext):
257258

258259
@ndb_context
259260
@trace_filter.log_trace
261+
@ndb.synctasklet
260262
def QueryAffectedBatch(self, request, context: grpc.ServicerContext):
261263
"""Query vulnerabilities (batch)."""
262264
batch_results = []
@@ -338,7 +340,7 @@ def QueryAffectedBatch(self, request, context: grpc.ServicerContext):
338340

339341
for future in futures:
340342
try:
341-
result, next_page_token = future.result()
343+
result, next_page_token = yield future
342344
except InvalidArgument:
343345
# Currently cannot think of any other way
344346
# this can be raised other than invalid cursor
@@ -356,9 +358,10 @@ def QueryAffectedBatch(self, request, context: grpc.ServicerContext):
356358

357359
@ndb_context
358360
@trace_filter.log_trace
361+
@ndb.synctasklet
359362
def DetermineVersion(self, request, context: grpc.ServicerContext):
360363
"""Determine the version of the provided hashes."""
361-
res = determine_version(request.query, context).result()
364+
res = yield determine_version(request.query, context)
362365
return res
363366

364367
@ndb_context
@@ -828,11 +831,9 @@ def do_query(query: osv_service_v1_pb2.Query,
828831
_MAX_VULN_LISTED_PRE_EXCEEDED_UBUNTU_EXCEPTION
829832

830833
def to_response(b: osv.Bug):
831-
# Skip retrieving aliases from to_vulnerability().
832-
# Retrieve it asynchronously later.
833834
return bug_to_response(b, include_details)
834835

835-
bugs: list[vulnerability_pb2.Vulnerability]
836+
bugs: list[ndb.Future]
836837
if query.WhichOneof('param') == 'commit':
837838
try:
838839
commit_bytes = codecs.decode(query.commit, 'hex')
@@ -857,31 +858,6 @@ def to_response(b: osv.Bug):
857858
# to know that control flow breaks here.
858859
raise ValueError
859860

860-
# Asynchronously retrieve computed aliases and related ids here
861-
# to prevent significant query time increase for packages with
862-
# numerous vulnerabilities.
863-
if include_details:
864-
aliases = []
865-
related = []
866-
for bug in bugs:
867-
aliases.append(osv.get_aliases_async(bug.id))
868-
related.append(osv.get_related_async(bug.id))
869-
870-
for i, alias in enumerate(aliases):
871-
alias_group: osv.AliasGroup = yield alias
872-
if not alias_group:
873-
continue
874-
alias_ids = sorted(list(set(alias_group.bug_ids) - {bugs[i].id}))
875-
bugs[i].aliases[:] = alias_ids
876-
modified_time = bugs[i].modified.ToDatetime(UTC)
877-
modified_time = max(alias_group.last_modified, modified_time)
878-
bugs[i].modified.FromDatetime(modified_time)
879-
880-
for i, related_ids in enumerate(related):
881-
related_bug_ids: list[str] = yield related_ids
882-
bugs[i].related[:] = sorted(
883-
list(set(related_bug_ids + list(bugs[i].related))))
884-
885861
if context.query_counter < context.input_cursor.query_number:
886862
logging.error(
887863
'Cursor is invalid - received "%d" while total query count is "%d".',
@@ -895,25 +871,26 @@ def to_response(b: osv.Bug):
895871
if next_page_token_str:
896872
logging.warning('Page size limit hit, response size: %s', len(bugs))
897873

898-
return bugs, next_page_token_str
874+
# Wait on all the bug futures
875+
bugs = yield bugs
876+
877+
return list(bugs), next_page_token_str
899878

900879

901-
def bug_to_response(bug: osv.Bug,
902-
include_details=True,
903-
include_alias=False) -> vulnerability_pb2.Vulnerability:
904-
"""Convert a Bug entity to a response object."""
880+
def bug_to_response(bug: osv.Bug, include_details=True) -> ndb.Future:
881+
"""Asynchronously convert a Bug entity to a response object."""
905882
if include_details:
906-
return bug.to_vulnerability(
907-
include_source=True, include_alias=include_alias)
883+
return bug.to_vulnerability_async(
884+
include_source=True, include_alias=True, include_upstream=True)
908885

909-
return bug.to_vulnerability_minimal()
886+
return bug.to_vulnerability_minimal_async(
887+
include_alias=True, include_upstream=True)
910888

911889

912890
@ndb.tasklet
913891
def _get_bugs(
914892
bug_ids: list[str],
915-
to_response: ToResponseCallable = bug_to_response
916-
) -> list[vulnerability_pb2.Vulnerability]:
893+
to_response: ToResponseCallable = bug_to_response) -> list[ndb.Future]:
917894
"""Get bugs from bug ids."""
918895
bugs = ndb.get_multi_async([ndb.Key(osv.Bug, bug_id) for bug_id in bug_ids
919896
]) # type: ignore
@@ -931,8 +908,7 @@ def _get_bugs(
931908
def query_by_commit(
932909
context: QueryContext,
933910
commit: bytes,
934-
to_response: ToResponseCallable = bug_to_response
935-
) -> list[vulnerability_pb2.Vulnerability]:
911+
to_response: ToResponseCallable = bug_to_response) -> list[ndb.Future]:
936912
"""
937913
Perform a query by commit.
938914
@@ -1190,8 +1166,7 @@ def query_by_version(
11901166
package_name: str | None,
11911167
ecosystem: str | None,
11921168
version: str,
1193-
to_response: ToResponseCallable = bug_to_response
1194-
) -> list[vulnerability_pb2.Vulnerability]:
1169+
to_response: ToResponseCallable = bug_to_response) -> list[ndb.Future]:
11951170
"""
11961171
Query by (fuzzy) version.
11971172
@@ -1332,9 +1307,9 @@ def _query_by_comparing_versions(context: QueryContext, query: ndb.Query,
13321307

13331308

13341309
@ndb.tasklet
1335-
def query_by_package(
1336-
context: QueryContext, package_name: str | None, ecosystem: str | None,
1337-
to_response: ToResponseCallable) -> list[vulnerability_pb2.Vulnerability]:
1310+
def query_by_package(context: QueryContext, package_name: str | None,
1311+
ecosystem: str | None,
1312+
to_response: ToResponseCallable) -> list[ndb.Future]:
13381313
"""
13391314
Query by package.
13401315

osv/models.py

Lines changed: 57 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -656,6 +656,37 @@ def to_vulnerability_minimal(self):
656656

657657
return vulnerability_pb2.Vulnerability(id=self.id(), modified=modified)
658658

659+
@ndb.tasklet
660+
def to_vulnerability_minimal_async(self,
661+
include_alias=True,
662+
include_upstream=True):
663+
"""Convert to Vulnerability proto (minimal) asynchronously."""
664+
modified_times = []
665+
if self.last_modified:
666+
modified_times.append(self.last_modified)
667+
668+
# Fetch the last_modified dates from the upstream/alias groups.
669+
alias_future = get_aliases_async(self.id()) if include_alias else None
670+
upstream_future = (
671+
get_upstream_async(self.id()) if include_upstream else None)
672+
673+
if include_alias:
674+
alias = yield alias_future
675+
if alias and alias.last_modified:
676+
modified_times.append(alias.last_modified)
677+
678+
if include_upstream:
679+
upstream = yield upstream_future
680+
if upstream and upstream.last_modified:
681+
modified_times.append(upstream.last_modified)
682+
683+
modified = None
684+
if modified_times:
685+
modified = timestamp_pb2.Timestamp()
686+
modified.FromDatetime(max(modified_times))
687+
688+
return vulnerability_pb2.Vulnerability(id=self.id(), modified=modified)
689+
659690
def to_vulnerability(self,
660691
include_source=False,
661692
include_alias=True,
@@ -809,17 +840,35 @@ def to_vulnerability_async(self,
809840
include_alias=False,
810841
include_upstream=False):
811842
"""Converts to Vulnerability proto and retrieves aliases asynchronously."""
843+
# Convert the vulnerability without any subqueries first.
812844
vulnerability: vulnerability_pb2.Vulnerability = self.to_vulnerability(
813-
include_source=include_source,
814-
include_alias=False,
815-
include_upstream=False)
845+
include_source=False, include_alias=False, include_upstream=False)
816846

817-
related_bug_ids = yield get_related_async(vulnerability.id)
818-
vulnerability.related[:] = sorted(
819-
list(set(related_bug_ids + list(vulnerability.related))))
847+
# Asynchronously make all necessary subqueries.
848+
if not self.source:
849+
include_source = False
850+
source_future = (
851+
SourceRepository.get_by_id_async(self.source)
852+
if include_source else None)
853+
related_future = (
854+
get_related_async(vulnerability.id) if include_alias else None)
855+
alias_future = (
856+
get_aliases_async(vulnerability.id) if include_alias else None)
857+
upstream_future = (
858+
get_upstream_async(vulnerability.id) if include_upstream else None)
859+
860+
if include_source:
861+
source_repo = yield source_future
862+
if source_repo and source_repo.link:
863+
source_link = source_repo.link + sources.source_path(source_repo, self)
864+
for affected in vulnerability.affected:
865+
affected.database_specific.update({'source': source_link})
820866

821867
if include_alias:
822-
alias_group = yield get_aliases_async(vulnerability.id)
868+
related_bug_ids = yield related_future
869+
vulnerability.related[:] = sorted(
870+
list(set(related_bug_ids + list(vulnerability.related))))
871+
alias_group = yield alias_future
823872
if alias_group:
824873
alias_ids = sorted(list(set(alias_group.bug_ids) - {vulnerability.id}))
825874
vulnerability.aliases[:] = alias_ids
@@ -828,7 +877,7 @@ def to_vulnerability_async(self,
828877
vulnerability.modified.FromDatetime(modified_time)
829878

830879
if include_upstream:
831-
upstream_group = yield get_upstream_async(vulnerability.id)
880+
upstream_group = yield upstream_future
832881
if upstream_group:
833882
vulnerability.upstream[:] = upstream_group.upstream_ids
834883
modified_time = vulnerability.modified.ToDatetime(datetime.UTC)

0 commit comments

Comments
 (0)