1616import argparse
1717import codecs
1818from dataclasses import dataclass
19- from datetime import datetime , timedelta , UTC
19+ from datetime import datetime , timedelta
2020import math
2121import hashlib
2222import functools
4343from osv import ecosystems
4444from osv import semver_index
4545from osv import purl_helpers
46- from osv import vulnerability_pb2
4746from osv .logs import setup_gcp_logging
4847import osv_service_v1_pb2
4948import osv_service_v1_pb2_grpc
101100# ----
102101# Type Aliases:
103102
104- ToResponseCallable = Callable [[osv .Bug ], vulnerability_pb2 . Vulnerability ]
103+ ToResponseCallable = Callable [[osv .Bug ], ndb . Future ]
105104
106105# ----
107106
@@ -163,14 +162,15 @@ class OSVServicer(osv_service_v1_pb2_grpc.OSVServicer,
163162
164163 @ndb_context
165164 @trace_filter .log_trace
165+ @ndb .synctasklet
166166 def GetVulnById (self , request , context : grpc .ServicerContext ):
167167 """Return a `Vulnerability` object for a given OSV ID."""
168- bug = osv .Bug .get_by_id ( request .id )
168+ bug = yield osv .Bug .query ( osv . Bug . db_id == request .id ). get_async ( )
169169
170170 if not bug :
171171 # Check for aliases
172- alias_group = osv .AliasGroup .query (
173- osv .AliasGroup .bug_ids == request .id ).get ()
172+ alias_group = yield osv .AliasGroup .query (
173+ osv .AliasGroup .bug_ids == request .id ).get_async ()
174174 if alias_group :
175175 alias_string = ' ' .join ([
176176 f'{ alias } ' for alias in alias_group .bug_ids if alias != request .id
@@ -190,10 +190,12 @@ def GetVulnById(self, request, context: grpc.ServicerContext):
190190 context .abort (grpc .StatusCode .PERMISSION_DENIED , 'Permission denied.' )
191191 return None
192192
193- return bug_to_response (bug , include_alias = True )
193+ resp = yield bug_to_response (bug , include_details = True )
194+ return resp
194195
195196 @ndb_context
196197 @trace_filter .log_trace
198+ @ndb .synctasklet
197199 def QueryAffected (self , request , context : grpc .ServicerContext ):
198200 """Query vulnerabilities for a particular project at a given commit or
199201
@@ -238,8 +240,7 @@ def QueryAffected(self, request, context: grpc.ServicerContext):
238240 total_responses = ResponsesCount (0 ))
239241
240242 try :
241- results , next_page_token = do_query (
242- request .query , query_context ).result () # type: ignore
243+ results , next_page_token = yield do_query (request .query , query_context )
243244 except InvalidArgument :
244245 # Currently cannot think of any other way
245246 # this can be raised other than invalid cursor
@@ -257,6 +258,7 @@ def QueryAffected(self, request, context: grpc.ServicerContext):
257258
258259 @ndb_context
259260 @trace_filter .log_trace
261+ @ndb .synctasklet
260262 def QueryAffectedBatch (self , request , context : grpc .ServicerContext ):
261263 """Query vulnerabilities (batch)."""
262264 batch_results = []
@@ -338,7 +340,7 @@ def QueryAffectedBatch(self, request, context: grpc.ServicerContext):
338340
339341 for future in futures :
340342 try :
341- result , next_page_token = future . result ()
343+ result , next_page_token = yield future
342344 except InvalidArgument :
343345 # Currently cannot think of any other way
344346 # this can be raised other than invalid cursor
@@ -356,9 +358,10 @@ def QueryAffectedBatch(self, request, context: grpc.ServicerContext):
356358
357359 @ndb_context
358360 @trace_filter .log_trace
361+ @ndb .synctasklet
359362 def DetermineVersion (self , request , context : grpc .ServicerContext ):
360363 """Determine the version of the provided hashes."""
361- res = determine_version (request .query , context ). result ( )
364+ res = yield determine_version (request .query , context )
362365 return res
363366
364367 @ndb_context
@@ -828,11 +831,9 @@ def do_query(query: osv_service_v1_pb2.Query,
828831 _MAX_VULN_LISTED_PRE_EXCEEDED_UBUNTU_EXCEPTION
829832
830833 def to_response (b : osv .Bug ):
831- # Skip retrieving aliases from to_vulnerability().
832- # Retrieve it asynchronously later.
833834 return bug_to_response (b , include_details )
834835
835- bugs : list [vulnerability_pb2 . Vulnerability ]
836+ bugs : list [ndb . Future ]
836837 if query .WhichOneof ('param' ) == 'commit' :
837838 try :
838839 commit_bytes = codecs .decode (query .commit , 'hex' )
@@ -857,31 +858,6 @@ def to_response(b: osv.Bug):
857858 # to know that control flow breaks here.
858859 raise ValueError
859860
860- # Asynchronously retrieve computed aliases and related ids here
861- # to prevent significant query time increase for packages with
862- # numerous vulnerabilities.
863- if include_details :
864- aliases = []
865- related = []
866- for bug in bugs :
867- aliases .append (osv .get_aliases_async (bug .id ))
868- related .append (osv .get_related_async (bug .id ))
869-
870- for i , alias in enumerate (aliases ):
871- alias_group : osv .AliasGroup = yield alias
872- if not alias_group :
873- continue
874- alias_ids = sorted (list (set (alias_group .bug_ids ) - {bugs [i ].id }))
875- bugs [i ].aliases [:] = alias_ids
876- modified_time = bugs [i ].modified .ToDatetime (UTC )
877- modified_time = max (alias_group .last_modified , modified_time )
878- bugs [i ].modified .FromDatetime (modified_time )
879-
880- for i , related_ids in enumerate (related ):
881- related_bug_ids : list [str ] = yield related_ids
882- bugs [i ].related [:] = sorted (
883- list (set (related_bug_ids + list (bugs [i ].related ))))
884-
885861 if context .query_counter < context .input_cursor .query_number :
886862 logging .error (
887863 'Cursor is invalid - received "%d" while total query count is "%d".' ,
@@ -895,25 +871,26 @@ def to_response(b: osv.Bug):
895871 if next_page_token_str :
896872 logging .warning ('Page size limit hit, response size: %s' , len (bugs ))
897873
898- return bugs , next_page_token_str
874+ # Wait on all the bug futures
875+ bugs = yield bugs
876+
877+ return list (bugs ), next_page_token_str
899878
900879
901- def bug_to_response (bug : osv .Bug ,
902- include_details = True ,
903- include_alias = False ) -> vulnerability_pb2 .Vulnerability :
904- """Convert a Bug entity to a response object."""
880+ def bug_to_response (bug : osv .Bug , include_details = True ) -> ndb .Future :
881+ """Asynchronously convert a Bug entity to a response object."""
905882 if include_details :
906- return bug .to_vulnerability (
907- include_source = True , include_alias = include_alias )
883+ return bug .to_vulnerability_async (
884+ include_source = True , include_alias = True , include_upstream = True )
908885
909- return bug .to_vulnerability_minimal ()
886+ return bug .to_vulnerability_minimal_async (
887+ include_alias = True , include_upstream = True )
910888
911889
912890@ndb .tasklet
913891def _get_bugs (
914892 bug_ids : list [str ],
915- to_response : ToResponseCallable = bug_to_response
916- ) -> list [vulnerability_pb2 .Vulnerability ]:
893+ to_response : ToResponseCallable = bug_to_response ) -> list [ndb .Future ]:
917894 """Get bugs from bug ids."""
918895 bugs = ndb .get_multi_async ([ndb .Key (osv .Bug , bug_id ) for bug_id in bug_ids
919896 ]) # type: ignore
@@ -931,8 +908,7 @@ def _get_bugs(
931908def query_by_commit (
932909 context : QueryContext ,
933910 commit : bytes ,
934- to_response : ToResponseCallable = bug_to_response
935- ) -> list [vulnerability_pb2 .Vulnerability ]:
911+ to_response : ToResponseCallable = bug_to_response ) -> list [ndb .Future ]:
936912 """
937913 Perform a query by commit.
938914
@@ -1190,8 +1166,7 @@ def query_by_version(
11901166 package_name : str | None ,
11911167 ecosystem : str | None ,
11921168 version : str ,
1193- to_response : ToResponseCallable = bug_to_response
1194- ) -> list [vulnerability_pb2 .Vulnerability ]:
1169+ to_response : ToResponseCallable = bug_to_response ) -> list [ndb .Future ]:
11951170 """
11961171 Query by (fuzzy) version.
11971172
@@ -1332,9 +1307,9 @@ def _query_by_comparing_versions(context: QueryContext, query: ndb.Query,
13321307
13331308
13341309@ndb .tasklet
1335- def query_by_package (
1336- context : QueryContext , package_name : str | None , ecosystem : str | None ,
1337- to_response : ToResponseCallable ) -> list [vulnerability_pb2 . Vulnerability ]:
1310+ def query_by_package (context : QueryContext , package_name : str | None ,
1311+ ecosystem : str | None ,
1312+ to_response : ToResponseCallable ) -> list [ndb . Future ]:
13381313 """
13391314 Query by package.
13401315
0 commit comments