|
4 | 4 | import textwrap |
5 | 5 | import uuid |
6 | 6 | from datetime import datetime |
| 7 | +from pathlib import Path |
7 | 8 | from typing import Any |
8 | 9 |
|
| 10 | +import pfzy |
9 | 11 | from discord import Color, Embed, Interaction, app_commands |
| 12 | +from discord.app_commands import Range |
10 | 13 | from discord.ext.commands import GroupCog |
11 | 14 | from discord.ui import Button, View |
12 | 15 | from githubkit import GitHub |
13 | 16 | from githubkit.exception import GitHubException |
14 | 17 | from githubkit.rest import Issue, PullRequest, SimpleUser |
| 18 | +from more_itertools import consecutive_groups |
| 19 | +from yarl import URL |
15 | 20 |
|
16 | | -from ghutils.core.cog import GHUtilsCog |
| 21 | +from ghutils.core.cog import GHUtilsCog, SubGroup |
| 22 | +from ghutils.core.types import LoginState, NotLoggedInError |
17 | 23 | from ghutils.db.models import ( |
18 | 24 | UserGitHubTokens, |
19 | 25 | UserLogin, |
@@ -244,6 +250,122 @@ async def status( |
244 | 250 |
|
245 | 251 | await respond_with_visibility(interaction, visibility, embed=embed) |
246 | 252 |
|
| 253 | + class Search(SubGroup): |
| 254 | + """Search for things on GitHub.""" |
| 255 | + |
| 256 | + @app_commands.command() |
| 257 | + async def files( |
| 258 | + self, |
| 259 | + interaction: Interaction, |
| 260 | + repo: FullRepositoryOption, |
| 261 | + query: Range[str, 1, 128], |
| 262 | + ref: Range[str, 1, 255] | None = None, |
| 263 | + exact: bool = False, |
| 264 | + limit: Range[int, 1, 25] = 5, |
| 265 | + visibility: MessageVisibility = "private", |
| 266 | + ): |
| 267 | + """Search for files in a repository by name. |
| 268 | +
|
| 269 | + Args: |
| 270 | + ref: Branch name, tag name, or commit to search in. Defaults to the |
| 271 | + default branch of the repo. |
| 272 | + exact: If true, use exact search; otherwise use fuzzy search. |
| 273 | + limit: Maximum number of results to show. |
| 274 | + """ |
| 275 | + |
| 276 | + async with self.bot.github_app(interaction) as (github, state): |
| 277 | + if state != LoginState.LOGGED_IN: |
| 278 | + raise NotLoggedInError() |
| 279 | + |
| 280 | + if ref is None: |
| 281 | + ref = repo.default_branch |
| 282 | + |
| 283 | + tree = await gh_request( |
| 284 | + github.rest.git.async_get_tree( |
| 285 | + repo.owner.login, |
| 286 | + repo.name, |
| 287 | + ref, |
| 288 | + recursive="1", |
| 289 | + ) |
| 290 | + ) |
| 291 | + |
| 292 | + sha = tree.sha[:12] |
| 293 | + tree_dict = {item.path: item for item in tree.tree if item.path} |
| 294 | + |
| 295 | + matches = await pfzy.fuzzy_match( |
| 296 | + query, |
| 297 | + list(tree_dict.keys()), |
| 298 | + scorer=pfzy.substr_scorer if exact else pfzy.fzy_scorer, |
| 299 | + ) |
| 300 | + |
| 301 | + embed = ( |
| 302 | + Embed( |
| 303 | + title="File search results", |
| 304 | + ) |
| 305 | + .set_author( |
| 306 | + name=repo.full_name, |
| 307 | + url=repo.html_url, |
| 308 | + icon_url=repo.owner.avatar_url, |
| 309 | + ) |
| 310 | + .set_footer( |
| 311 | + text=f"{repo.full_name}@{ref} • Total results: {len(matches)}", |
| 312 | + ) |
| 313 | + ) |
| 314 | + |
| 315 | + # code search only works on the default branch |
| 316 | + # so don't add the link otherwise, since it won't be useful |
| 317 | + if ref == repo.default_branch: |
| 318 | + embed.url = str( |
| 319 | + URL("https://github.com/search").with_query( |
| 320 | + type="code", |
| 321 | + q=f'repo:{repo.full_name} path:"{query}"', |
| 322 | + ) |
| 323 | + ) |
| 324 | + |
| 325 | + if matches: |
| 326 | + embed.color = Color.green() |
| 327 | + else: |
| 328 | + embed.description = "⚠️ No matches found." |
| 329 | + embed.color = Color.red() |
| 330 | + |
| 331 | + size = 0 |
| 332 | + for match in matches[:limit]: |
| 333 | + path: str = match["value"] |
| 334 | + indices: list[int] = match["indices"] |
| 335 | + |
| 336 | + item = tree_dict[path] |
| 337 | + |
| 338 | + icon = "📁" if item.type == "tree" else "📄" |
| 339 | + url = f"https://github.com/{repo.full_name}/{item.type}/{sha}/{item.path}" |
| 340 | + |
| 341 | + parts = list[str]() |
| 342 | + index = 0 |
| 343 | + for group in consecutive_groups(indices): |
| 344 | + group = list(group) |
| 345 | + parts += [ |
| 346 | + # everything before the start of the group |
| 347 | + path[index : group[0]], |
| 348 | + "**", |
| 349 | + # everything in the group |
| 350 | + path[group[0] : group[-1] + 1], |
| 351 | + "**", |
| 352 | + ] |
| 353 | + index = group[-1] + 1 |
| 354 | + # everything after the last group |
| 355 | + parts.append(path[index:]) |
| 356 | + highlighted_path = "".join(parts) |
| 357 | + |
| 358 | + name = f"{icon} {Path(path).name}" |
| 359 | + value = f"[{highlighted_path}]({url})" |
| 360 | + |
| 361 | + size += len(name) + len(value) |
| 362 | + if size > 5000: |
| 363 | + break |
| 364 | + |
| 365 | + embed.add_field(name=name, value=value, inline=False) |
| 366 | + |
| 367 | + await respond_with_visibility(interaction, visibility, embed=embed) |
| 368 | + |
247 | 369 |
|
248 | 370 | def _discord_date(timestamp: int | float | datetime): |
249 | 371 | match timestamp: |
|
0 commit comments