Skip to content

Commit 5d0cb90

Browse files
committed
Harden kd_tree is_solution checks
1 parent bff4f2a commit 5d0cb90

File tree

1 file changed

+48
-26
lines changed

1 file changed

+48
-26
lines changed

AlgoTuneTasks/kd_tree/kd_tree.py

Lines changed: 48 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -245,26 +245,62 @@ def is_solution(self, problem: dict[str, Any], solution: dict[str, Any]) -> bool
245245
logging.error("Negative distances found.")
246246
return False
247247

248-
sample_size = min(10, n_queries)
249-
for i in np.random.choice(n_queries, sample_size, replace=False):
250-
for j in range(min(3, k)):
251-
idx = indices[i, j]
252-
computed = np.sum((queries[i] - points[idx]) ** 2)
253-
if not np.isclose(computed, distances[i, j], rtol=1e-4, atol=1e-4):
248+
# Deterministic sampling keeps validation stable across runs while scaling to large inputs.
249+
rng = np.random.default_rng(0)
250+
sample_size = min(64, n_queries)
251+
if sample_size == n_queries:
252+
sample_idx = np.arange(n_queries, dtype=int)
253+
else:
254+
sample_idx = rng.choice(n_queries, sample_size, replace=False)
255+
256+
# Reject degenerate rows with repeated neighbor ids (common shortcut pattern).
257+
if k > 1 and n_queries > 0:
258+
sorted_idx = np.sort(indices, axis=1)
259+
dup_rows = np.where(np.any(np.diff(sorted_idx, axis=1) == 0, axis=1))[0]
260+
if dup_rows.size > 0:
261+
logging.error(f"Duplicate neighbor indices detected for query {int(dup_rows[0])}.")
262+
return False
263+
264+
# Validate reported squared distances for all neighbors on sampled queries.
265+
for i in sample_idx:
266+
row_indices = indices[i]
267+
computed = np.sum((points[row_indices] - queries[i]) ** 2, axis=1)
268+
if not np.allclose(computed, distances[i], rtol=1e-4, atol=1e-4):
269+
max_abs = float(np.max(np.abs(computed - distances[i])))
270+
logging.error(
271+
f"Distance mismatch for query {int(i)}. Max absolute error: {max_abs:.6g}"
272+
)
273+
return False
274+
275+
# Distances must be sorted in ascending order for every query row.
276+
if n_queries > 0 and k > 1:
277+
unsorted_rows = np.where(np.any(np.diff(distances, axis=1) < -1e-5, axis=1))[0]
278+
if unsorted_rows.size > 0:
279+
logging.error(f"Distances not sorted ascending for query {int(unsorted_rows[0])}.")
280+
return False
281+
282+
# Require nearest-neighbor recall on the sampled queries for every dimensionality, not only when dim > 10.
283+
if k > 0 and n_queries > 0:
284+
min_recall = 0.95 if dim <= 10 else max(0.3, 1.0 - (dim / 200.0))
285+
for idx in sample_idx:
286+
all_dist2 = np.sum((points - queries[idx]) ** 2, axis=1)
287+
true_idx = np.argsort(all_dist2)[:k]
288+
recall = len(np.intersect1d(indices[idx], true_idx)) / float(k)
289+
if recall < min_recall:
254290
logging.error(
255-
f"Distance mismatch for query {i}, neighbor {j}. "
256-
f"Computed: {computed}, Provided: {distances[i, j]}"
291+
f"Recall {recall:.2f} below {min_recall:.2f} for query {int(idx)} (dim={dim})."
257292
)
258293
return False
259-
for i in range(n_queries):
260-
if not np.all(np.diff(distances[i]) >= -1e-5):
261-
logging.error(f"Distances not sorted ascending for query {i}.")
262-
return False
263294

264295
# ======== Boundary case handling ========
265296
if problem.get("distribution") == "hypercube_shell":
266297
pts = points.astype(np.float64)
267298
bq_idx = np.array(solution.get("boundary_indices", []))
299+
if bq_idx.shape != (2 * dim, k):
300+
logging.error(
301+
f"Boundary indices shape incorrect. Expected {(2 * dim, k)}, got {bq_idx.shape}."
302+
)
303+
return False
268304
bqs = []
269305
for d in range(dim):
270306
q0 = np.zeros(dim, dtype=np.float64)
@@ -289,18 +325,4 @@ def is_solution(self, problem: dict[str, Any], solution: dict[str, Any]) -> bool
289325
"Skipping boundary checks for distribution=%s", problem.get("distribution")
290326
)
291327

292-
# ======== High-dimensional correctness ========
293-
if dim > 10:
294-
sample = min(5, n_queries)
295-
for idx in np.random.choice(n_queries, sample, replace=False):
296-
all_dist = np.sqrt(np.sum((points - queries[idx]) ** 2, axis=1))
297-
true_idx = np.argsort(all_dist)[:k]
298-
recall = len(np.intersect1d(indices[idx], true_idx)) / float(k)
299-
min_acc = max(0.3, 1.0 - (dim / 200.0))
300-
if recall < min_acc:
301-
logging.error(
302-
f"High-dimensional recall {recall:.2f} below {min_acc:.2f} for dim {dim}"
303-
)
304-
return False
305-
306328
return True

0 commit comments

Comments
 (0)