@@ -245,26 +245,62 @@ def is_solution(self, problem: dict[str, Any], solution: dict[str, Any]) -> bool
245245 logging .error ("Negative distances found." )
246246 return False
247247
248- sample_size = min (10 , n_queries )
249- for i in np .random .choice (n_queries , sample_size , replace = False ):
250- for j in range (min (3 , k )):
251- idx = indices [i , j ]
252- computed = np .sum ((queries [i ] - points [idx ]) ** 2 )
253- if not np .isclose (computed , distances [i , j ], rtol = 1e-4 , atol = 1e-4 ):
248+ # Deterministic sampling keeps validation stable across runs while scaling to large inputs.
249+ rng = np .random .default_rng (0 )
250+ sample_size = min (64 , n_queries )
251+ if sample_size == n_queries :
252+ sample_idx = np .arange (n_queries , dtype = int )
253+ else :
254+ sample_idx = rng .choice (n_queries , sample_size , replace = False )
255+
256+ # Reject degenerate rows with repeated neighbor ids (common shortcut pattern).
257+ if k > 1 and n_queries > 0 :
258+ sorted_idx = np .sort (indices , axis = 1 )
259+ dup_rows = np .where (np .any (np .diff (sorted_idx , axis = 1 ) == 0 , axis = 1 ))[0 ]
260+ if dup_rows .size > 0 :
261+ logging .error (f"Duplicate neighbor indices detected for query { int (dup_rows [0 ])} ." )
262+ return False
263+
264+ # Validate reported squared distances for all neighbors on sampled queries.
265+ for i in sample_idx :
266+ row_indices = indices [i ]
267+ computed = np .sum ((points [row_indices ] - queries [i ]) ** 2 , axis = 1 )
268+ if not np .allclose (computed , distances [i ], rtol = 1e-4 , atol = 1e-4 ):
269+ max_abs = float (np .max (np .abs (computed - distances [i ])))
270+ logging .error (
271+ f"Distance mismatch for query { int (i )} . Max absolute error: { max_abs :.6g} "
272+ )
273+ return False
274+
275+ # Distances must be sorted in ascending order for every query row.
276+ if n_queries > 0 and k > 1 :
277+ unsorted_rows = np .where (np .any (np .diff (distances , axis = 1 ) < - 1e-5 , axis = 1 ))[0 ]
278+ if unsorted_rows .size > 0 :
279+ logging .error (f"Distances not sorted ascending for query { int (unsorted_rows [0 ])} ." )
280+ return False
281+
282+ # Require nearest-neighbor recall for sampled queries across all dimensions.
283+ if k > 0 and n_queries > 0 :
284+ min_recall = 0.95 if dim <= 10 else max (0.3 , 1.0 - (dim / 200.0 ))
285+ for idx in sample_idx :
286+ all_dist2 = np .sum ((points - queries [idx ]) ** 2 , axis = 1 )
287+ true_idx = np .argsort (all_dist2 )[:k ]
288+ recall = len (np .intersect1d (indices [idx ], true_idx )) / float (k )
289+ if recall < min_recall :
254290 logging .error (
255- f"Distance mismatch for query { i } , neighbor { j } . "
256- f"Computed: { computed } , Provided: { distances [i , j ]} "
291+ f"Recall { recall :.2f} below { min_recall :.2f} for query { int (idx )} (dim={ dim } )."
257292 )
258293 return False
259- for i in range (n_queries ):
260- if not np .all (np .diff (distances [i ]) >= - 1e-5 ):
261- logging .error (f"Distances not sorted ascending for query { i } ." )
262- return False
263294
264295 # ======== Boundary case handling ========
265296 if problem .get ("distribution" ) == "hypercube_shell" :
266297 pts = points .astype (np .float64 )
267298 bq_idx = np .array (solution .get ("boundary_indices" , []))
299+ if bq_idx .shape != (2 * dim , k ):
300+ logging .error (
301+ f"Boundary indices shape incorrect. Expected { (2 * dim , k )} , got { bq_idx .shape } ."
302+ )
303+ return False
268304 bqs = []
269305 for d in range (dim ):
270306 q0 = np .zeros (dim , dtype = np .float64 )
@@ -289,18 +325,4 @@ def is_solution(self, problem: dict[str, Any], solution: dict[str, Any]) -> bool
289325 "Skipping boundary checks for distribution=%s" , problem .get ("distribution" )
290326 )
291327
292- # ======== High-dimensional correctness ========
293- if dim > 10 :
294- sample = min (5 , n_queries )
295- for idx in np .random .choice (n_queries , sample , replace = False ):
296- all_dist = np .sqrt (np .sum ((points - queries [idx ]) ** 2 , axis = 1 ))
297- true_idx = np .argsort (all_dist )[:k ]
298- recall = len (np .intersect1d (indices [idx ], true_idx )) / float (k )
299- min_acc = max (0.3 , 1.0 - (dim / 200.0 ))
300- if recall < min_acc :
301- logging .error (
302- f"High-dimensional recall { recall :.2f} below { min_acc :.2f} for dim { dim } "
303- )
304- return False
305-
306328 return True
0 commit comments