Skip to content

Commit 5814295

Browse files
authored
Merge pull request #5 from LaunchPlatform/traverse-format
Traverse main by default if no file name provided
2 parents e9c5679 + 0219db5 commit 5814295

File tree

4 files changed

+88
-24
lines changed

4 files changed

+88
-24
lines changed

beanhub_cli/format.py

Lines changed: 35 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77
import click
88
from beancount_black.formatter import Formatter
99
from beancount_parser.parser import make_parser
10+
from beancount_parser.parser import traverse
11+
from lark import Lark
12+
from lark import Tree
1013

1114
from .cli import cli
1215
from .environment import Environment
@@ -32,6 +35,16 @@ def create_backup(src: pathlib.Path, suffix: str) -> pathlib.Path:
3235
return backup_path
3336

3437

38+
def file_tree_iterator(
39+
parser: Lark, filenames: list[pathlib.Path]
40+
) -> typing.Generator[tuple[pathlib.Path, Tree], None, None]:
41+
for filename in filenames:
42+
with open(filename, "rt") as input_file:
43+
input_content = input_file.read()
44+
tree = parser.parse(input_content)
45+
yield filename, tree
46+
47+
3548
@cli.command(name="format", help="Format Beancount files with beancount-black")
3649
@click.argument("filename", type=click.Path(exists=False, dir_okay=False), nargs=-1)
3750
@click.option(
@@ -53,33 +66,43 @@ def main(
5366
backup: bool,
5467
):
5568
# TODO: support follow include statements
56-
5769
parser = make_parser()
58-
formatter = Formatter()
5970
if stdin_mode:
6071
env.logger.info("Processing in stdin mode")
6172
input_content = sys.stdin.read()
6273
tree = parser.parse(input_content)
74+
formatter = Formatter()
6375
formatter.format(tree, sys.stdout)
6476
else:
65-
for name in filename:
66-
env.logger.info("Processing file %s", name)
67-
with open(name, "rt") as input_file:
68-
input_content = input_file.read()
69-
tree = parser.parse(input_content)
77+
if filename:
78+
iterator = file_tree_iterator(
79+
parser=parser,
80+
filenames=map(lambda item: pathlib.Path(str(item)), filename),
81+
)
82+
else:
83+
env.logger.info("No files provided, traverse starting from main.bean")
84+
iterator = traverse(
85+
parser=parser,
86+
bean_file=pathlib.Path("main.bean"),
87+
root_dir=pathlib.Path.cwd(),
88+
)
89+
for filepath, tree in iterator:
90+
env.logger.info("Processing file %s", filepath)
7091
with tempfile.NamedTemporaryFile(mode="wt+", suffix=".bean") as output_file:
92+
formatter = Formatter()
7193
formatter.format(tree, output_file)
7294
output_file.seek(0)
7395
output_content = output_file.read()
96+
input_content = filepath.read_text()
7497
if input_content == output_content:
75-
env.logger.info("File %s is not changed, skip", name)
98+
env.logger.info("File %s is not changed, skip", filepath)
7699
continue
77100
if backup:
78-
backup_path = create_backup(
79-
src=pathlib.Path(str(name)), suffix=backup_suffix
101+
backup_path = create_backup(src=filepath, suffix=backup_suffix)
102+
env.logger.info(
103+
"File %s changed, backup to %s", filepath, backup_path
80104
)
81-
env.logger.info("File %s changed, backup to %s", name, backup_path)
82105
output_file.seek(0)
83-
with open(name, "wt") as input_file:
106+
with open(filepath, "wt") as input_file:
84107
shutil.copyfileobj(output_file, input_file)
85108
env.logger.info("done")

tests/forms/test_validator.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
import contextlib
2-
import os
31
import pathlib
42

53
import pytest
@@ -8,22 +6,13 @@
86
from click.testing import CliRunner
97
from pydantic import ValidationError
108

9+
from ..helper import switch_cwd
1110
from beanhub_cli.forms.validator import format_loc
1211
from beanhub_cli.forms.validator import merge_index_loc
1312
from beanhub_cli.forms.validator import validate_doc
1413
from beanhub_cli.main import cli
1514

1615

17-
@contextlib.contextmanager
18-
def switch_cwd(cwd: pathlib.Path):
19-
current_cwd = pathlib.Path.cwd()
20-
try:
21-
os.chdir(cwd)
22-
yield
23-
finally:
24-
os.chdir(current_cwd)
25-
26-
2716
@pytest.mark.parametrize(
2817
"loc, expected",
2918
[
@@ -118,9 +107,11 @@ def test_bad_schema(tmp_path: pathlib.Path, schema: dict, expected_errors: list)
118107
yaml.dump(schema, fo)
119108
with pytest.raises(ValidationError) as exc:
120109
validate_doc(doc_file)
110+
121111
def del_url(d):
122112
del d["url"]
123113
return d
114+
124115
assert list(map(del_url, exc.value.errors())) == expected_errors
125116

126117

tests/helper.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import contextlib
2+
import os
3+
import pathlib
4+
5+
6+
@contextlib.contextmanager
7+
def switch_cwd(cwd: pathlib.Path):
8+
current_cwd = pathlib.Path.cwd()
9+
try:
10+
os.chdir(cwd)
11+
yield
12+
finally:
13+
os.chdir(current_cwd)

tests/test_format.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import pathlib
2+
3+
from click.testing import CliRunner
4+
5+
from .helper import switch_cwd
6+
from beanhub_cli.main import cli
7+
8+
9+
def test_format_cmd(tmp_path: pathlib.Path, cli_runner: CliRunner):
10+
beanhub_dir = tmp_path / ".beanhub"
11+
beanhub_dir.mkdir()
12+
13+
bean_file = beanhub_dir / "sample.bean"
14+
bean_file.write_text("2024-06-27 open Assets:Cash")
15+
16+
cli_runner.mix_stderr = False
17+
with switch_cwd(tmp_path):
18+
result = cli_runner.invoke(cli, ["format", str(bean_file)])
19+
assert result.exit_code == 0
20+
assert bean_file.read_text() == "2024-06-27 open Assets:Cash\n"
21+
22+
23+
def test_format_cmd_without_args(tmp_path: pathlib.Path, cli_runner: CliRunner):
24+
beanhub_dir = tmp_path / ".beanhub"
25+
beanhub_dir.mkdir()
26+
27+
included_bean = beanhub_dir / "mybook.bean"
28+
included_bean.write_text("2024-06-27 open Assets:Cash")
29+
30+
main_bean = beanhub_dir / "main.bean"
31+
main_bean.write_text('include "mybook.bean"')
32+
33+
cli_runner.mix_stderr = False
34+
with switch_cwd(beanhub_dir):
35+
result = cli_runner.invoke(cli, ["format"])
36+
assert result.exit_code == 0
37+
assert included_bean.read_text() == "2024-06-27 open Assets:Cash\n"

0 commit comments

Comments
 (0)