import re
import pickle
from math import sqrt
from collections import Counter, defaultdict
from difflib import SequenceMatcher

# Characters treated as punctuation: ASCII punctuation, common CJK punctuation,
# and the general-punctuation / CJK-symbol / fullwidth-form Unicode ranges.
PUNCTS_PAT = re.compile(
    r'(?:[#\$&@.,;:!?,。!?、:; \u3300\'`"~_\+\-\*\/\\|\\^=<>\[\]\(\)\{\}()“”‘’\s]|'
    r'[\u2000-\u206f]|'
    r'[\u3000-\u303f]|'
    r'[\uff30-\uff4f]|'
    r'[\uff00-\uff0f\uff1a-\uff20\uff3b-\uff40\uff5b-\uff65])+'
)


def make_terms(text, term, ngram_range=None, lower=True, ignore_punct=True, gram_as_tuple=False):
    """Split `text` into 'word' or 'char' terms and expand them into n-grams.

    `ngram_range` is a half-open (min, max) pair: (1, 3) yields unigrams and
    bigrams; the default (1, 2) yields unigrams only. N-grams that contain
    punctuation are dropped when `ignore_punct` is set.
    """
    if lower:
        text = text.lower()
    if term == 'word':
        # term_seq = [word.strip() for word in jieba.cut(text) if word.strip()]
        term_seq = [word.strip() for word in text.split() if word.strip()]
    elif term == 'char':
        term_seq = list(re.sub(r'\s', '', text))
    else:
        raise ValueError("unsupported term type: {}".format(term))

    if ngram_range and not (len(ngram_range) == 2 and ngram_range[0] < ngram_range[1]):
        raise ValueError("wrong `ngram_range`: {}".format(ngram_range))

    terms = []
    min_ngram, max_ngram = ngram_range or (1, 2)
    for idx in range(0, max(1, len(term_seq) - min_ngram + 1)):
        cur_grams = []
        for gram_level in range(min_ngram, max_ngram):
            if gram_as_tuple:
                gram = tuple(term_seq[idx:idx + gram_level])
            else:
                gram = ''.join(term_seq[idx:idx + gram_level])
            if gram in cur_grams:
                continue
            if ignore_punct and any(PUNCTS_PAT.match(item) for item in gram):
                continue
            cur_grams.append(gram)
        terms.extend(cur_grams)
    return terms


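# A quick sketch of expected output, traced by hand from make_terms above
# (illustrative, not generated test output):
#   make_terms("ab cd", term='char', ngram_range=(1, 3))
#       -> ['a', 'ab', 'b', 'bc', 'c', 'cd', 'd']
#   make_terms("Hello World", term='word')
#       -> ['hello', 'world']

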
def lcs_sim(s1, s2, term='char', ngram_range=None, ngram_weights=None,
            lower=True, ignore_punct=True):
    # `term`, `ngram_range` and `ngram_weights` are accepted only to keep the
    # signature in line with the other sim functions; the similarity is always
    # computed over the character sequences via difflib.SequenceMatcher.
    s1_terms = make_terms(s1, 'char', None, lower, ignore_punct)
    s2_terms = make_terms(s2, 'char', None, lower, ignore_punct)
    return SequenceMatcher(a=s1_terms, b=s2_terms).ratio()


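# A quick sketch of expected values, traced by hand (identical strings score
# 1.0; "abcd" vs "abce" share the block "abc", so ratio = 2 * 3 / 8 = 0.75):
#   lcs_sim("abcd", "abcd") -> 1.0
#   lcs_sim("abcd", "abce") -> 0.75

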
def jaccard_sim(s1, s2, term='word', ngram_range=None, ngram_weights=None,
                lower=True, ignore_punct=True):
    """Jaccard similarity of the term sets; for a multi-level `ngram_range`,
    the per-level scores are combined as a weighted average."""
    if not ngram_range or ngram_range[1] == ngram_range[0] + 1:
        first_term_set = set(make_terms(s1, term, ngram_range, lower, ignore_punct))
        second_term_set = set(make_terms(s2, term, ngram_range, lower, ignore_punct))
        if not first_term_set and not second_term_set:
            return 1.0
        return len(first_term_set & second_term_set) / len(first_term_set | second_term_set)
    else:
        weights = ngram_weights or list(range(*ngram_range))
        weights_sum = sum(weights)
        weights = [weight / weights_sum for weight in weights]
        scores = []
        for ngram_level in range(*ngram_range):
            score = jaccard_sim(s1, s2, term=term,
                                ngram_range=(ngram_level, ngram_level + 1),
                                lower=lower, ignore_punct=ignore_punct)
            scores.append(score)

        return sum(score * weight for score, weight in zip(scores, weights))


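# A quick sketch, traced by hand: the word sets {machine, learning} and
# {deep, learning} share 1 of 3 distinct words.
#   jaccard_sim("machine learning", "deep learning") -> 1/3 ≈ 0.33

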
def cosine_sim(s1, s2, term='word', ngram_range=None, ngram_weights=None,
               lower=True, ignore_punct=True):
    """Cosine similarity of the term-frequency vectors; for a multi-level
    `ngram_range`, the per-level scores are combined as a weighted average."""
    if not ngram_range or ngram_range[1] == ngram_range[0] + 1:
        first_term_freq = Counter(make_terms(s1, term, ngram_range, lower, ignore_punct))
        second_term_freq = Counter(make_terms(s2, term, ngram_range, lower, ignore_punct))

        first_norm = 0
        second_norm = 0
        inner_product = 0

        # avoid shadowing the `term` argument with the loop variable
        for tok, freq in first_term_freq.items():
            first_norm += freq ** 2
            inner_product += freq * second_term_freq[tok]

        for freq in second_term_freq.values():
            second_norm += freq ** 2

        if first_norm == 0 and second_norm == 0:
            return 1.0
        if first_norm == 0 or second_norm == 0:
            return 0.0

        return inner_product / sqrt(first_norm * second_norm)
    else:
        weights = ngram_weights or list(range(*ngram_range))
        weights_sum = sum(weights)
        weights = [weight / weights_sum for weight in weights]
        scores = []
        for ngram_level in range(*ngram_range):
            score = cosine_sim(s1, s2, term=term,
                               ngram_range=(ngram_level, ngram_level + 1),
                               lower=lower, ignore_punct=ignore_punct)
            scores.append(score)

        return sum(score * weight for score, weight in zip(scores, weights))


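# A quick sketch, traced by hand: the word-count vectors of "machine learning"
# and "deep learning" have inner product 1 and norms sqrt(2) each.
#   cosine_sim("machine learning", "deep learning") -> 1 / 2 = 0.5

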
def sim_of(s1, s2, method='cosine', term='word', ngram_range=None, lower=True, ignore_punct=True):
    method_func = {
        'lcs': lcs_sim,
        'jaccard': jaccard_sim,
        'cosine': cosine_sim,
    }.get(method)
    if not method_func:
        raise ValueError("unsupported method: {}".format(method))

    return method_func(s1, s2, term=term, ngram_range=ngram_range,
                       lower=lower, ignore_punct=ignore_punct)


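# A quick dispatcher sketch, traced by hand: the word sets {the, cat, sat} and
# {the, cat, ran} share 2 of 4 distinct words.
#   sim_of("the cat sat", "the cat ran", method='jaccard') -> 2 / 4 = 0.5

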
class InvIndex(object):
    def __init__(self):
        """A character-trigram inverted index over document titles."""
        self._id2doc = {}
        self._index = defaultdict(set)

    def add_doc(self, doc):
        """Index `doc` (any object with `id` and `title` attributes)."""
        if doc.id in self._id2doc:
            return False

        self._id2doc[doc.id] = doc.title
        terms = set(make_terms(doc.title, 'char', (3, 4)))
        for term in terms:
            self._index[term].add(doc.id)

        return True

    def retrieve(self, query, k=10):
        """Return up to `k` (id, title, score) tuples, scored by the number of
        character trigrams shared with `query`."""
        related = Counter()
        terms = set(make_terms(query, 'char', (3, 4)))
        for term in terms:
            for qid in self._index.get(term, []):
                related[qid] += 1

        return [(idx, self._id2doc[idx], score) for idx, score in related.most_common(k)]

    def save(self, fname):
        with open(fname, 'wb') as fout:
            pickle.dump((self._id2doc, self._index), fout)

    def load(self, fname):
        with open(fname, 'rb') as fin:
            self._id2doc, self._index = pickle.load(fin)
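

# A minimal, hypothetical usage sketch: the index only needs objects with `id`
# and `title` attributes, so a namedtuple stands in for the real document type.
if __name__ == '__main__':
    from collections import namedtuple

    Doc = namedtuple('Doc', ['id', 'title'])  # hypothetical stand-in doc type

    index = InvIndex()
    index.add_doc(Doc(id=1, title='how to train a text classifier'))
    index.add_doc(Doc(id=2, title='how to cook fried rice'))

    # Retrieve candidates by shared character trigrams, then score them more
    # precisely with one of the similarity functions above.
    for doc_id, title, overlap in index.retrieve('train a classifier', k=5):
        print(doc_id, title, overlap, sim_of('train a classifier', title))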