1
- from dataclasses import dataclass
1
+ from dataclasses import dataclass , field
2
+ from functools import cached_property
2
3
from pathlib import Path
3
4
4
- from tree_sitter_languages import get_language , get_parser
5
+ import tree_sitter_lua
6
+ import tree_sitter_markdown
7
+ from tree_sitter import Language , Parser
5
8
6
9
7
10
@dataclass (frozen = True )
8
11
class LuaClass :
9
12
name : str
10
- fields : list [str ]
13
+ fields : list [str ] = field ( default_factory = list )
11
14
12
- def is_it (self , name : str ) -> bool :
13
- return name in self .name
15
+ @cached_property
16
+ def class_name (self ) -> str :
17
+ # ---@class render.md.Init: render.md.Api -> Init
18
+ # ---@class (exact) render.md.UserBufferConfig -> UserBufferConfig
19
+ # ---@class (exact) render.md.UserConfig: render.md.UserBufferConfig -> UserConfig
20
+ return self .name .split (":" )[0 ].split ()[- 1 ].split ("." )[- 1 ]
21
+
22
+ @cached_property
23
+ def is_user (self ) -> bool :
24
+ return self .class_name .startswith ("User" )
14
25
15
26
def is_optional (self , field : str ) -> bool :
16
- optional_fields : list [ str ] = [ " extends" , "scope_highlight" , "quote_icon" ]
17
- for optional_field in optional_fields :
18
- if optional_field in field :
19
- return True
20
- return False
27
+ # ---@field public extends? boolean -> extends
28
+ # ---@field public start_row integer -> start_row
29
+ # ---@field public attach? fun(buf: integer) -> attach
30
+ field_name = field . split ()[ 2 ]. replace ( "?" , "" )
31
+ return field_name in [ "extends" , "scope_highlight" , "quote_icon" ]
21
32
22
33
def validate (self ) -> None :
23
34
for field in self .fields :
24
- if self .is_it ("User" ):
25
- self .validate_user_field (field )
26
- else :
27
- self .validate_non_user_field (field )
28
-
29
- def validate_user_field (self , field : str ) -> None :
30
- # User classes are expected to have optional fields
31
- assert "?" in field , f"Field must be optional: { field } "
32
-
33
- def validate_non_user_field (self , field : str ) -> None :
34
- # Non user classes are expected to have mandatory fields with some exceptions
35
- if not self .is_optional (field ):
36
- assert "?" not in field , f"Field must be mandatory: { field } "
37
-
38
- def to_public_lines (self ) -> list [str ]:
39
- if not self .is_it ("User" ):
40
- return []
41
-
42
- lines : list [str ] = []
43
- lines .append (self .name .replace ("User" , "" ))
35
+ # User classes are expected to have optional fields with no exceptions
36
+ # Internal classes are expected to have mandatory fields with some exceptions
37
+ optional = self .is_user or self .is_optional (field )
38
+ message = "optional" if optional else "mandatory"
39
+ assert ("?" in field ) == optional , f"Field must be { message } : { field } "
40
+
41
+ def to_internal (self ) -> str :
42
+ lines : list [str ] = [self .name .replace ("User" , "" )]
44
43
for field in self .fields :
45
- if self .is_it ( "ConfigOverrides" ) :
44
+ if self .class_name == "UserConfigOverrides" :
46
45
lines .append (field .replace ("?" , "" ))
47
46
elif self .is_optional (field ):
48
47
lines .append (field )
49
48
else :
50
49
lines .append (field .replace ("User" , "" ).replace ("?" , "" ))
51
- lines .append ("" )
52
- return lines
50
+ return "\n " .join (lines )
53
51
54
52
def to_str (self ) -> str :
55
- lines : list [str ] = [self .name ]
56
- lines .extend (self .fields )
57
- return "\n " .join (lines )
53
+ return "\n " .join ([self .name ] + self .fields )
58
54
59
- @staticmethod
60
- def from_lines (lines : list [str ]) -> "LuaClass" :
61
- return LuaClass (name = lines [0 ], fields = lines [1 :])
55
+
56
+ INIT_LUA = Path ("lua/render-markdown/init.lua" )
57
+ TYPES_LUA = Path ("lua/render-markdown/types.lua" )
58
+ README_MD = Path ("README.md" )
59
+ HANDLERS_MD = Path ("doc/custom-handlers.md" )
62
60
63
61
64
62
def main () -> None :
65
- init_file = Path ("lua/render-markdown/init.lua" )
66
- update_types (init_file , Path ("lua/render-markdown/types.lua" ))
67
- update_readme (init_file , Path ("README.md" ))
68
- update_custom_handlers (init_file , Path ("doc/custom-handlers.md" ))
63
+ update_types ()
64
+ update_readme ()
65
+ update_handlers ()
69
66
70
67
71
- def update_types (init_file : Path , types_file : Path ) -> None :
72
- lines : list [str ] = ["---@meta" , " " ]
73
- for lua_class in get_classes (init_file ):
68
+ def update_types () -> None :
69
+ classes : list [str ] = ["---@meta" ]
70
+ for lua_class in get_classes ():
74
71
lua_class .validate ()
75
- lines .extend (lua_class .to_public_lines ())
76
- types_file .write_text ("\n " .join (lines ))
72
+ if lua_class .is_user :
73
+ classes .append (lua_class .to_internal ())
74
+ TYPES_LUA .write_text ("\n \n " .join (classes ) + "\n " )
75
+
77
76
77
+ def update_readme () -> None :
78
78
79
- def update_readme (init_file : Path , readme_file : Path ) -> None :
80
- old_config = get_code_block (readme_file , "log_level" , 1 )
81
- new_config = wrap_setup (get_default_config (init_file ))
82
- text = readme_file .read_text ().replace (old_config , new_config )
79
+ def wrap_setup (value : str ) -> str :
80
+ return f"require('render-markdown').setup({ value } )\n "
81
+
82
+ old_config = get_code_block (README_MD , "log_level" , 1 )
83
+ new_config = wrap_setup (get_default_config ())
84
+ text = README_MD .read_text ().replace (old_config , new_config )
83
85
84
86
parameters : list [str ] = [
85
87
"heading" ,
@@ -96,52 +98,41 @@ def update_readme(init_file: Path, readme_file: Path) -> None:
96
98
"indent" ,
97
99
]
98
100
for parameter in parameters :
99
- old_param = get_code_block (readme_file , f"\n { parameter } = {{" , 2 )
101
+ old_param = get_code_block (README_MD , f"\n { parameter } = {{" , 2 )
100
102
new_param = wrap_setup (get_config_for (new_config , parameter ))
101
103
text = text .replace (old_param , new_param )
102
104
103
- readme_file .write_text (text )
105
+ README_MD .write_text (text )
104
106
105
107
106
- def update_custom_handlers (init_file : Path , handler_file : Path ) -> None :
107
- class_name : str = "render.md.Handler"
108
- old = get_code_block (handler_file , class_name , 1 )
109
- new = [
110
- get_class (init_file , "render.md.Mark" ).to_str (),
111
- "" ,
112
- get_class (init_file , class_name ).to_str (),
113
- ]
114
- text = handler_file .read_text ().replace (old , "\n " .join (new ))
115
- handler_file .write_text (text )
108
+ def update_handlers () -> None :
109
+ name_to_lua = {lua .class_name : lua for lua in get_classes ()}
110
+ mark = name_to_lua ["Mark" ]
111
+ handler = name_to_lua ["Handler" ]
116
112
113
+ old = get_code_block (HANDLERS_MD , mark .name , 1 )
114
+ new = "\n " .join ([mark .to_str (), "" , handler .to_str (), "" ])
115
+ text = HANDLERS_MD .read_text ().replace (old , new )
116
+ HANDLERS_MD .write_text (text )
117
117
118
- def get_class (init_file : Path , name : str ) -> LuaClass :
119
- lua_classes = get_classes (init_file )
120
- results = [lua_class for lua_class in lua_classes if name in lua_class .name ]
121
- assert len (results ) == 1
122
- return results [0 ]
123
118
124
-
125
- def get_classes (init_file : Path ) -> list [LuaClass ]:
126
- # Group comments into class + fields
119
+ def get_classes () -> list [LuaClass ]:
127
120
lua_classes : list [LuaClass ] = []
128
- current : list [str ] = []
129
- for comment in get_comments (init_file ):
130
- comment_type : str = comment .split ()[0 ].split ("@" )[- 1 ]
131
- if comment_type == "class" :
132
- if len (current ) > 0 :
133
- lua_classes .append (LuaClass .from_lines (current ))
134
- current = [comment ]
135
- elif comment_type == "field" :
136
- current .append (comment )
137
- lua_classes .append (LuaClass .from_lines (current ))
121
+ for comment in get_comments ():
122
+ # ---@class render.md.Init: render.md.Api -> class
123
+ # ---@field public enabled? boolean -> field
124
+ # ---@alias render.md.code.Width 'full'|'block' -> alias
125
+ # ---@type render.md.Config -> type
126
+ # ---@param opts? render.md.UserConfig -> param
127
+ # -- Inlined with 'image' elements -> --
128
+ annotation = comment .split ()[0 ].split ("@" )[- 1 ]
129
+ if annotation == "class" :
130
+ lua_classes .append (LuaClass (comment ))
131
+ elif annotation == "field" :
132
+ lua_classes [- 1 ].fields .append (comment )
138
133
return lua_classes
139
134
140
135
141
- def wrap_setup (value : str ) -> str :
142
- return f"require('render-markdown').setup({ value } )"
143
-
144
-
145
136
def get_config_for (config : str , parameter : str ) -> str :
146
137
lines : list [str ] = config .splitlines ()
147
138
start : int = lines .index (f" { parameter } = {{" )
@@ -153,22 +144,24 @@ def get_config_for(config: str, parameter: str) -> str:
153
144
return "\n " .join (["{" ] + lines [start : end + 1 ] + ["}" ])
154
145
155
146
156
- def get_comments (file : Path ) -> list [str ]:
147
+ def get_comments () -> list [str ]:
157
148
query = "(comment) @comment"
158
- return ts_query (file , query , "comment" )
149
+ return ts_query (INIT_LUA , query , "comment" )
159
150
160
151
161
- def get_default_config (file : Path ) -> str :
152
+ def get_default_config () -> str :
162
153
query = """
163
- (variable_assignment(
164
- (variable_list(
165
- variable field: (identifier) @name
166
- (#eq? @name "default_config")
167
- ))
168
- (expression_list value: (table)) @value
169
- ))
154
+ (assignment_statement
155
+ (variable_list
156
+ name: (dot_index_expression
157
+ field: (identifier) @name
158
+ (#eq? @name "default_config")
159
+ )
160
+ )
161
+ (expression_list value: (table_constructor)) @value
162
+ )
170
163
"""
171
- default_configs = ts_query (file , query , "value" )
164
+ default_configs = ts_query (INIT_LUA , query , "value" )
172
165
assert len (default_configs ) == 1
173
166
return default_configs [0 ]
174
167
@@ -182,21 +175,19 @@ def get_code_block(file: Path, content: str, n: int) -> str:
182
175
183
176
184
177
def ts_query (file : Path , query : str , target : str ) -> list [str ]:
185
- ts_language : str = {
186
- ".lua" : "lua" ,
187
- ".md" : "markdown" ,
178
+ tree_sitter = {
179
+ ".lua" : tree_sitter_lua ,
180
+ ".md" : tree_sitter_markdown ,
188
181
}[file .suffix ]
189
- parser = get_parser (ts_language )
190
- tree = parser .parse (file .read_text ().encode ())
191
182
192
- ts_query = get_language (ts_language ).query (query )
193
- captures = ts_query .captures (tree .root_node )
183
+ language = Language (tree_sitter .language ())
184
+ tree = Parser (language ).parse (file .read_text ().encode ())
185
+ captures = language .query (query ).captures (tree .root_node )
194
186
195
- values : list [str ] = []
196
- for node , capture in captures :
197
- if capture == target :
198
- values .append (node .text .decode ())
199
- return values
187
+ nodes = captures [target ]
188
+ nodes .sort (key = lambda node : node .start_byte )
189
+ texts = [node .text for node in nodes ]
190
+ return [text .decode () for text in texts if text is not None ]
200
191
201
192
202
193
if __name__ == "__main__" :
0 commit comments