77import click
88from beancount_black .formatter import Formatter
99from beancount_parser .parser import make_parser
10+ from beancount_parser .parser import traverse
11+ from lark import Lark
12+ from lark import Tree
1013
1114from .cli import cli
1215from .environment import Environment
@@ -32,6 +35,16 @@ def create_backup(src: pathlib.Path, suffix: str) -> pathlib.Path:
3235 return backup_path
3336
3437
38+ def file_tree_iterator (
39+ parser : Lark , filenames : list [pathlib .Path ]
40+ ) -> typing .Generator [tuple [pathlib .Path , Tree ], None , None ]:
41+ for filename in filenames :
42+ with open (filename , "rt" ) as input_file :
43+ input_content = input_file .read ()
44+ tree = parser .parse (input_content )
45+ yield filename , tree
46+
47+
3548@cli .command (name = "format" , help = "Format Beancount files with beancount-black" )
3649@click .argument ("filename" , type = click .Path (exists = False , dir_okay = False ), nargs = - 1 )
3750@click .option (
@@ -53,33 +66,43 @@ def main(
5366 backup : bool ,
5467):
5568 # TODO: support follow include statements
56-
5769 parser = make_parser ()
58- formatter = Formatter ()
5970 if stdin_mode :
6071 env .logger .info ("Processing in stdin mode" )
6172 input_content = sys .stdin .read ()
6273 tree = parser .parse (input_content )
74+ formatter = Formatter ()
6375 formatter .format (tree , sys .stdout )
6476 else :
65- for name in filename :
66- env .logger .info ("Processing file %s" , name )
67- with open (name , "rt" ) as input_file :
68- input_content = input_file .read ()
69- tree = parser .parse (input_content )
77+ if filename :
78+ iterator = file_tree_iterator (
79+ parser = parser ,
80+ filenames = map (lambda item : pathlib .Path (str (item )), filename ),
81+ )
82+ else :
83+ env .logger .info ("No files provided, traverse starting from main.bean" )
84+ iterator = traverse (
85+ parser = parser ,
86+ bean_file = pathlib .Path ("main.bean" ),
87+ root_dir = pathlib .Path .cwd (),
88+ )
89+ for filepath , tree in iterator :
90+ env .logger .info ("Processing file %s" , filepath )
7091 with tempfile .NamedTemporaryFile (mode = "wt+" , suffix = ".bean" ) as output_file :
92+ formatter = Formatter ()
7193 formatter .format (tree , output_file )
7294 output_file .seek (0 )
7395 output_content = output_file .read ()
96+ input_content = filepath .read_text ()
7497 if input_content == output_content :
75- env .logger .info ("File %s is not changed, skip" , name )
98+ env .logger .info ("File %s is not changed, skip" , filepath )
7699 continue
77100 if backup :
78- backup_path = create_backup (
79- src = pathlib .Path (str (name )), suffix = backup_suffix
101+ backup_path = create_backup (src = filepath , suffix = backup_suffix )
102+ env .logger .info (
103+ "File %s changed, backup to %s" , filepath , backup_path
80104 )
81- env .logger .info ("File %s changed, backup to %s" , name , backup_path )
82105 output_file .seek (0 )
83- with open (name , "wt" ) as input_file :
106+ with open (filepath , "wt" ) as input_file :
84107 shutil .copyfileobj (output_file , input_file )
85108 env .logger .info ("done" )
0 commit comments