|
| 1 | +import csv |
| 2 | +import math |
| 3 | +import argparse |
| 4 | +from algoliasearch.search.client import SearchClientSync |
| 5 | + |
# Algolia connection settings for the index whose search results are scored.
# NOTE(review): the application ID and API key are hard-coded here; presumably
# this is a public search-only key — confirm before committing.
ALGOLIA_APP_ID = "62VCH2MD74"
ALGOLIA_API_KEY = "b78244d947484fe3ece7bc5472e9f2af"
ALGOLIA_INDEX_NAME = "clickhouse"

# Module-level synchronous client; main() issues every search through it.
client = SearchClientSync(ALGOLIA_APP_ID, ALGOLIA_API_KEY)
| 12 | + |
| 13 | + |
def compute_dcg(relevance_scores):
    """Return the Discounted Cumulative Gain (DCG) of a ranked score list.

    The score at 1-based rank ``r`` is divided by ``log2(r + 1)``, so the
    first result is undiscounted and later results count progressively less.
    An empty list yields 0.
    """
    ranked_gains = enumerate(relevance_scores, start=1)
    return sum(gain / math.log2(rank + 1) for rank, gain in ranked_gains)
| 17 | + |
| 18 | + |
def compute_ndcg(expected_links, retrieved_links, k):
    """Compute the normalized DCG (nDCG) of the top-``k`` retrieved links.

    Relevance is binary: a retrieved link scores 1 the first time it matches
    an expected link and 0 otherwise. Repeated occurrences of the same link
    earn no further credit, and the ideal ranking is sized by the number of
    distinct expected links, so the result cannot exceed 1.

    Args:
        expected_links: Links considered relevant for the query.
        retrieved_links: Links returned by the search, best first.
        k: Number of top results to evaluate.

    Returns:
        DCG divided by the ideal DCG. Returns 0.0 when the ideal DCG is zero
        (no expected links, or ``k`` smaller than 1).
    """
    relevant = set(expected_links)

    # Credit each expected link once; otherwise duplicate hits would push the
    # DCG above the ideal DCG.
    credited = set()
    relevance_scores = []
    for link in retrieved_links[:k]:
        is_new_match = link in relevant and link not in credited
        if is_new_match:
            credited.add(link)
        relevance_scores.append(1 if is_new_match else 0)
    dcg = compute_dcg(relevance_scores)

    # Best case: every distinct expected link occupies the top positions.
    ideal_relevance_scores = [1] * min(len(relevant), k)
    idcg = compute_dcg(ideal_relevance_scores)

    return dcg / idcg if idcg > 0 else 0.0
| 28 | + |
| 29 | + |
def main(input_csv, detailed, k=3):
    """Compute and print the mean nDCG for the search terms listed in a CSV.

    Each CSV row holds a search term in its first column followed by up to
    three expected links. Every term is sent to the Algolia index, the URLs
    of the top ``k`` hits are scored against the expected links, and the mean
    score over all evaluated terms is printed. Blank lines and rows without
    a search term are skipped and do not count towards the mean.

    Args:
        input_csv: Path to the CSV file to read.
        detailed: When true, also print a per-term table of scores.
        k: Number of hits to request and score per term.
    """
    with open(input_csv, mode='r', newline='', encoding='utf-8') as file:
        # csv.reader yields [] for blank lines; indexing such a row would
        # raise IndexError, so keep only rows that carry a search term.
        rows = [row for row in csv.reader(file) if row and row[0].strip()]

    results = []
    total_ndcg = 0
    for row in rows:
        term = row[0]
        expected_links = [link for link in row[1:4] if link]  # Skip empty cells

        # Query Algolia
        response = client.search(
            search_method_params={
                "requests": [
                    {
                        "indexName": ALGOLIA_INDEX_NAME,
                        "query": term,
                        "hitsPerPage": k,
                    },
                ],
            },
        )
        retrieved_links = [hit.url for hit in response.results[0].actual_instance.hits]

        # Compute nDCG
        ndcg = compute_ndcg(expected_links, retrieved_links, k)
        total_ndcg += ndcg
        results.append({"term": term, "nDCG": ndcg})

    # Calculate Mean nDCG over the rows that were actually evaluated
    mean_ndcg = total_ndcg / len(rows) if rows else 0

    # Display results
    print(f"Mean nDCG: {mean_ndcg:.4f}")
    if detailed:
        print("\nSearch Term\t\tnDCG")
        print("=" * 30)
        for result in results:
            print(f"{result['term']}\t\t{result['nDCG']:.4f}")
| 72 | + |
if __name__ == "__main__":
    # Command-line entry point: an optional CSV path plus a flag that turns
    # on the per-term breakdown.
    cli = argparse.ArgumentParser(description="Compute nDCG for Algolia search results.")
    cli.add_argument(
        "input_csv",
        nargs="?",
        default="results.csv",
        help="Path to the input CSV file (default: results.csv).",
    )
    cli.add_argument(
        "-d", "--detailed",
        action="store_true",
        help="Print detailed results for each search term.",
    )
    options = cli.parse_args()

    main(input_csv=options.input_csv, detailed=options.detailed)
0 commit comments