|
3 | 3 | from __future__ import annotations |
4 | 4 |
|
5 | 5 | import enum |
| 6 | +import importlib |
6 | 7 | import random |
7 | 8 | import re |
8 | 9 | import string |
| 10 | +import subprocess |
9 | 11 | import sys |
10 | 12 | import warnings |
11 | 13 | from functools import lru_cache |
@@ -83,7 +85,7 @@ def do_format_code(code: str, line_length: int) -> str: |
83 | 85 | code = code.strip() |
84 | 86 | if len(code) < line_length: |
85 | 87 | return code |
86 | | - formatter = _get_black_formatter() |
| 88 | + formatter = _get_formatter() |
87 | 89 | return formatter(code, line_length) |
88 | 90 |
|
89 | 91 |
|
@@ -118,7 +120,7 @@ def _format_signature(name: Markup, signature: str, line_length: int) -> str: |
118 | 120 | # Black cannot format names with dots, so we replace |
119 | 121 | # the whole name with a string of equal length |
120 | 122 | name_length = len(name) |
121 | | - formatter = _get_black_formatter() |
| 123 | + formatter = _get_formatter() |
122 | 124 | formatable = f"def {'x' * name_length}{signature}: pass" |
123 | 125 | formatted = formatter(formatable, line_length) |
124 | 126 |
|
@@ -434,12 +436,53 @@ def do_filter_objects( |
434 | 436 |
|
435 | 437 |
|
436 | 438 | @lru_cache(maxsize=1) |
437 | | -def _get_black_formatter() -> Callable[[str, int], str]: |
| 439 | +def _get_formatter() -> Callable[[str, int], str]: |
| 440 | + for formatter_function in [ |
| 441 | + _get_black_formatter, |
| 442 | + _get_ruff_formatter, |
| 443 | + ]: |
| 444 | + if (formatter := formatter_function()) is not None: |
| 445 | + return formatter |
| 446 | + |
| 447 | + logger.info("Formatting signatures requires either Black or ruff to be installed.") |
| 448 | + return lambda text, _: text |
| 449 | + |
| 450 | + |
| 451 | +@lru_cache(maxsize=1) |
| 452 | +def _get_ruff_formatter() -> Callable[[str, int], str] | None: |
| 453 | + if importlib.util.find_spec("ruff") is None: |
| 454 | + return None |
| 455 | + |
| 456 | + def formatter(code: str, line_length: int) -> str: |
| 457 | + try: |
| 458 | + completed_process = subprocess.run( # noqa: S603 |
| 459 | + [ # noqa: S607 |
| 460 | + "ruff", |
| 461 | + "format", |
| 462 | + f'--config "line-length={line_length}"', |
| 463 | + "--stdin-filename", |
| 464 | + "file.py", |
| 465 | + "-", |
| 466 | + ], |
| 467 | + check=True, |
| 468 | + capture_output=True, |
| 469 | + text=True, |
| 470 | + input=code, |
| 471 | + ) |
| 472 | + except subprocess.CalledProcessError: |
| 473 | + return code |
| 474 | + else: |
| 475 | + return completed_process.stdout |
| 476 | + |
| 477 | + return formatter |
| 478 | + |
| 479 | + |
| 480 | +@lru_cache(maxsize=1) |
| 481 | +def _get_black_formatter() -> Callable[[str, int], str] | None: |
438 | 482 | try: |
439 | 483 | from black import InvalidInput, Mode, format_str |
440 | 484 | except ModuleNotFoundError: |
441 | | - logger.info("Formatting signatures requires Black to be installed.") |
442 | | - return lambda text, _: text |
| 485 | + return None |
443 | 486 |
|
444 | 487 | def formatter(code: str, line_length: int) -> str: |
445 | 488 | mode = Mode(line_length=line_length) |
|
0 commit comments