@@ -1,16 +1,28 @@
 import copy
 import enum
+import itertools
 import logging
+import re
 import threading
 import urllib
 from collections import defaultdict
 from dataclasses import dataclass, field
-from typing import DefaultDict, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
+from typing import (
+    DefaultDict,
+    Dict,
+    Iterable,
+    List,
+    NamedTuple,
+    Optional,
+    Sequence,
+    Tuple,
+    Union,
+)
 
 import requests.exceptions
 from typing_extensions import Protocol
 
-from . import intersphinx, n, specparser
+from . import intersphinx, n, specparser, util
 from .cache import Cache
 from .n import FileId
 from .types import ProjectConfig, normalize_target
@@ -21,6 +33,8 @@
 #: current project, or a URL (from an intersphinx inventory).
 TargetType = enum.Enum("TargetType", ("fileid", "url"))
 
+PAT_TARGET_PART_SEPARATOR = re.compile(r"[_-]+")
+
 
 @dataclass
 class TargetDatabase:
@@ -68,9 +82,9 @@ def __getitem__(self, key: str) -> Sequence["TargetDatabase.Result"]:
                         canonical_target_name,
                         title,
                     )
-                    for canonical_target_name, fileid, title, html5_id in self.local_definitions[
-                        key
-                    ]
+                    for canonical_target_name, fileid, title, html5_id in self.local_definitions.get(
+                        key, []
+                    )
                 )
         except KeyError:
             pass
@@ -107,6 +121,46 @@ def __getitem__(self, key: str) -> Sequence["TargetDatabase.Result"]:
 
         return results
 
+    def get_suggestions(self, key: str) -> Sequence[str]:
+        key = normalize_target(key)
+        key = key.split(":", 2)[2]
+        candidates: List[str] = []
+
+        with self.lock:
+            intersphinx_keys: Iterable[str] = itertools.chain.from_iterable(
+                (str(s) for s in inventory.targets.keys())
+                for inventory in self.intersphinx_inventories.values()
+            )
+            all_keys: Iterable[str] = itertools.chain(
+                self.local_definitions.keys(), intersphinx_keys
+            )
+
+            key_parts = PAT_TARGET_PART_SEPARATOR.split(key)
+
+            for original_key_definition in all_keys:
+                key_definition = original_key_definition.split(":", 2)[2]
+                if abs(len(key) - len(key_definition)) > 2:
+                    continue
+
+                # Tokens tend to be separated by "-" and "_": if there's a different
+                # number of separators, don't attempt a typo correction.
+                key_definition_parts = PAT_TARGET_PART_SEPARATOR.split(key_definition)
+                if len(key_definition_parts) != len(key_parts):
+                    continue
+
+                # Evaluate each part separately, since we can abort early on a bad part.
+                # Small bonus: each distance computation is only O(N*M) per part pair.
+                if all(
+                    dist <= 2
+                    for dist in (
+                        util.damerau_levenshtein_distance(p1, p2)
+                        for p1, p2 in zip(key_parts, key_definition_parts)
+                    )
+                ):
+                    candidates.append(original_key_definition)
+
+        return candidates
+
     def define_local_target(
         self,
         domain: str,
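
The new method leans on `util.damerau_levenshtein_distance`, which this diff doesn't show. As a point of reference, here is a minimal sketch of the restricted ("optimal string alignment") Damerau-Levenshtein variant such a helper commonly implements — an assumption, not the `util` module's actual code:

```python
# Hedged sketch: assumes util.damerau_levenshtein_distance behaves like the
# standard "optimal string alignment" variant; the real helper may differ.
def damerau_levenshtein_distance(a: str, b: str) -> int:
    # d[i][j] holds the edit distance between a[:i] and b[:j].
    d = [[0] * (len(b) + 1) for _ in range(len(a) + 1)]
    for i in range(len(a) + 1):
        d[i][0] = i  # delete all i characters of a
    for j in range(len(b) + 1):
        d[0][j] = j  # insert all j characters of b
    for i in range(1, len(a) + 1):
        for j in range(1, len(b) + 1):
            cost = 0 if a[i - 1] == b[j - 1] else 1
            d[i][j] = min(
                d[i - 1][j] + 1,         # deletion
                d[i][j - 1] + 1,         # insertion
                d[i - 1][j - 1] + cost,  # substitution (or match)
            )
            # Transposition of two adjacent characters, e.g. "ab" <-> "ba".
            if i > 1 and j > 1 and a[i - 1] == b[j - 2] and a[i - 2] == b[j - 1]:
                d[i][j] = min(d[i][j], d[i - 2][j - 2] + 1)
    return d[len(a)][len(b)]

assert damerau_levenshtein_distance("indx", "index") == 1   # one insertion
assert damerau_levenshtein_distance("idnex", "index") == 1  # one transposition
```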
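And a rough illustration of how `get_suggestions` matches a mistyped target, using hypothetical target names (the `domain:name:target` key shape is inferred from the `split(":", 2)[2]` calls):

```python
import re
from typing import List

PAT_TARGET_PART_SEPARATOR = re.compile(r"[_-]+")

def target_parts(key: str) -> List[str]:
    # Drop the "domain:name:" prefix, then split on runs of "-" and "_".
    return PAT_TARGET_PART_SEPARATOR.split(key.split(":", 2)[2])

query = "std:label:create-an-indx"       # hypothetical mistyped target
candidate = "std:label:create-an-index"  # hypothetical defined target

qp, cp = target_parts(query), target_parts(candidate)
# ["create", "an", "indx"] vs ["create", "an", "index"]: the part counts match
# and the per-part distances are 0, 0, 1 -- all within the threshold of 2 --
# so get_suggestions would offer the candidate as a correction.
assert len(qp) == len(cp)
```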