2727#
2828
2929# stdlib
30- from ast import Tuple
30+ import ast
31+ from collections import defaultdict
32+ from typing import Tuple , Type
3133
3234# 3rd party
3335from tokenize_rt import src_to_tokens , tokens_to_src
4648
4749__all__ = ["trailing_commas_hook" ]
4850
51+ # TODO: leave closing commas alone
4952
50- def trailing_commas_hook (source : str , ** kwargs ) -> str :
53+
54+ def trailing_commas_hook (
55+ source : str ,
56+ min_version : Tuple [int , int ] = (3 , 6 ),
57+ ** kwargs ,
58+ ) -> str :
5159 r"""
5260 Call `add-trailing-comma <https://github.com/asottile/add-trailing-comma>`_, using the given keyword arguments as its configuration.
5361
@@ -58,20 +66,31 @@ def trailing_commas_hook(source: str, **kwargs) -> str:
5866 """
5967
6068 ast_obj = ast_parse (source )
61- min_version : Tuple [int , int ] = kwargs .get ("min-version" , (3 , 6 ))
6269
63- callbacks = visit (FUNCS , ast_obj , min_version )
70+ enabled_funcs = defaultdict (list )
71+
72+ ast_class : Type [ast .AST ]
73+ for ast_class , func in FUNCS .items ():
74+ class_name = ast_class .__name__
75+ option_name = f"format_{ class_name } "
76+
77+ enabled : bool = kwargs .get (option_name .lower (), kwargs .get (option_name , True ))
78+
79+ if enabled :
80+ enabled_funcs [ast_class ] = func
81+
82+ callbacks = visit (enabled_funcs , ast_obj , min_version )
6483
6584 tokens = src_to_tokens (source )
6685 for i , token in _changing_list (tokens ):
67- # DEDENT is a zero length token
68- if not token .src :
69- continue
86+ # # DEDENT is a zero length token
87+ # if not token.src:
88+ # continue
7089
7190 for callback in callbacks .get (token .offset , ()):
7291 callback (i , tokens )
7392
74- if token .src in START_BRACES :
75- fix_brace (tokens , find_simple (i , tokens ), add_comma = False , remove_comma = False )
93+ # if token.src in START_BRACES:
94+ # fix_brace(tokens, find_simple(i, tokens), add_comma=False, remove_comma=False)
7695
7796 return tokens_to_src (tokens )
0 commit comments