#!/usr/bin/env python
import argparse
import logging
import sys
from typing import List

import openai
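
# Client-side reimplementation of semantic search against the legacy (pre-1.0)
# openai Python SDK: each document is scored by the mean token log probability
# of the query given the document, normalized against an empty document.
# Hypothetical invocation (assumes OPENAI_API_KEY is set in the environment;
# the script name is illustrative):
#   python search.py -e ada -q "rainy day" -d "corduroy" -d "umbrella"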

logger = logging.getLogger()
formatter = logging.Formatter("[%(asctime)s] [%(process)d] %(message)s")
handler = logging.StreamHandler(sys.stderr)
handler.setFormatter(formatter)
logger.addHandler(handler)

DEFAULT_COND_LOGP_TEMPLATE = (
    "<|endoftext|>{document}\n\n---\n\nThe above passage is related to: {query}"
)
SCORE_MULTIPLIER = 100.0


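# Wraps one (document, query) pair: builds the scoring prompt and converts the
# echoed token logprobs from a completion choice into a single score.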
class SearchScorer:
    def __init__(
        self, *, document, query, cond_logp_template=DEFAULT_COND_LOGP_TEMPLATE
    ):
        self.document = document
        self.query = query
        self.cond_logp_template = cond_logp_template
        self.context = self.cond_logp_template.format(
            document=self.document, query=self.query
        )

    def get_context(self):
        return self.context

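    # Walk tokens backwards from the end of the echoed prompt until their
    # combined character length covers the query, then average the log
    # probabilities of those trailing tokens and scale by SCORE_MULTIPLIER.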
    def get_score(self, choice) -> float:
        assert choice.text == self.context
        logprobs: List[float] = choice.logprobs.token_logprobs
        text = choice.logprobs.tokens
        text_len = sum(len(token) for token in text)
        if text_len != len(self.context):
            raise RuntimeError(
                f"text_len={text_len}, len(self.context)={len(self.context)}"
            )
        total_len = 0
        last_used = len(text)
        while total_len < len(self.query):
            assert last_used > 0
            total_len += len(text[last_used - 1])
            last_used -= 1
        max_len = len(self.context) - self.cond_logp_template.index("{document}")
        assert total_len + len(self.document) <= max_len
        logits: List[float] = logprobs[last_used:]
        return sum(logits) / len(logits) * SCORE_MULTIPLIER


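# Scores every document against the query in a single batched Completion call
# and returns the results in a search-endpoint-style response shape.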
def semantic_search(engine, query, documents):
    # add empty document as baseline
    scorers = [
        SearchScorer(document=document, query=query) for document in [""] + documents
    ]
    completion = openai.Completion.create(
        engine=engine,
        prompt=[scorer.get_context() for scorer in scorers],
        max_tokens=0,
        logprobs=0,
        echo=True,
    )
    # sort the choices back into prompt order so index 0 is the empty-document baseline
    data = sorted(completion.choices, key=lambda choice: choice.index)
    assert len(scorers) == len(
        data
    ), f"len(scorers)={len(scorers)} len(data)={len(data)}"
    scores = [scorer.get_score(choice) for scorer, choice in zip(scorers, data)]
    # subtract score for empty document
    scores = [score - scores[0] for score in scores][1:]
    data = {
        "object": "list",
        "data": [
            {
                "object": "search_result",
                "document": document_idx,
                "score": round(score, 3),
            }
            for document_idx, score in enumerate(scores)
        ],
        "model": completion.model,
    }
    return data


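# CLI entry point: uses the client-side scorer above by default, or the
# engine's server-side search endpoint when --server-side is passed.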
def main():
    parser = argparse.ArgumentParser(
        description="Rank documents against a query with semantic search."
    )
    parser.add_argument(
        "-v",
        "--verbose",
        action="count",
        dest="verbosity",
        default=0,
        help="Set verbosity.",
    )
    parser.add_argument("-e", "--engine", default="ada", help="Engine to query.")
    parser.add_argument("-q", "--query", required=True, help="Search query.")
    parser.add_argument(
        "-d",
        "--document",
        action="append",
        required=True,
        help="Document to score; repeat the flag to add more documents.",
    )
    parser.add_argument(
        "-s",
        "--server-side",
        action="store_true",
        help="Use the server-side search endpoint instead of the client-side scorer.",
    )
    args = parser.parse_args()

    if args.verbosity == 1:
        logger.setLevel(logging.INFO)
    elif args.verbosity >= 2:
        logger.setLevel(logging.DEBUG)

    if args.server_side:
        resp = openai.Engine(id=args.engine).search(
            query=args.query, documents=args.document
        )
        resp = resp.to_dict_recursive()
        print(f"[server-side semantic search] {resp}")
    else:
        resp = semantic_search(args.engine, query=args.query, documents=args.document)
        print(f"[client-side semantic search] {resp}")

    return 0


if __name__ == "__main__":
    sys.exit(main())