Skip to content

Commit 0f06ea2

Browse files
committed
improve formatting
1 parent 9d2d903 commit 0f06ea2

File tree

2 files changed

+276
-113
lines changed

2 files changed

+276
-113
lines changed

src/tfdocs/utils.py

Lines changed: 176 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,32 @@
1-
import copy
21
import os
32
import re
4-
53
import git
64

75

86
def count_blocks(data):
9-
string = "".join(data) if isinstance(data, list) else data
7+
s = "".join(data) if isinstance(data, list) else data
8+
opens = {"{": "}", "(": ")", "[": "]", "<": ">"}
9+
closes = {v: k for k, v in opens.items()}
1010
stack = []
11-
12-
block_constructors = {"{": "}", "(": ")", "[": "]", "<": ">"}
13-
closing_brackets = set(block_constructors.values())
1411

15-
for char in string:
16-
if char in block_constructors:
17-
stack.append(char)
18-
elif char in closing_brackets and stack:
19-
if char == block_constructors[stack[-1]]:
20-
stack.pop()
12+
in_string = False
13+
esc = False
14+
15+
for ch in s:
16+
if ch == '"' and not esc:
17+
in_string = not in_string
18+
esc = (ch == "\\") and not esc
19+
if in_string:
20+
continue
2121

22-
return len(stack) == 0
22+
if ch in opens:
23+
stack.append(ch)
24+
elif ch in closes:
25+
if not stack or stack[-1] != closes[ch]:
26+
return False # early exit on mismatch
27+
stack.pop()
28+
29+
return not stack
2330

2431

2532
def process_line_block(line_block, target_type, content, cont):
@@ -47,122 +54,190 @@ def process_line_block(line_block, target_type, content, cont):
4754
return content, cont
4855

4956

57+
_TYPE_CONSTRUCTORS_RE = re.compile(r"\b(list|set|map|object|tuple)\b")
58+
5059
def match_type_constructors(string):
51-
type_constructors = ["list", "set", "map", "object", "tuple"]
60+
return _TYPE_CONSTRUCTORS_RE.search(string) is not None
5261

53-
pattern = r"\b(" + "|".join(type_constructors) + r")\b"
5462

55-
if re.search(pattern, string):
56-
return True
57-
else:
58-
return False
63+
def format_block(input_str: str, indent_level: int = 0, inline: bool = False) -> str:
64+
input_str = input_str.strip()
65+
indent = " " * indent_level
66+
67+
if input_str.startswith("{") and input_str.endswith("}"):
68+
return format_map(input_str[1:-1], indent_level, inline)
5969

70+
if input_str.startswith("[") and input_str.endswith("]"):
71+
return format_list(input_str[1:-1], indent_level)
6072

61-
def format_block(content):
62-
input_str = content.strip()
73+
if "(" in input_str and input_str.endswith(")"):
74+
return format_function_call(input_str, indent_level, inline)
6375

64-
if "{" not in input_str:
65-
return input_str
76+
return indent + input_str
6677

67-
def add_missing_commas(s):
68-
return re.sub(r'([}\]"\w])(\s+)(\w+\s*=)', r'\1,\2\3', s)
6978

70-
def smart_split(s):
71-
result = []
72-
current = ''
73-
depth = 0
74-
for char in s:
75-
if char in '{[':
79+
def smart_split(s):
80+
result = []
81+
current = ''
82+
depth = 0
83+
in_string = False
84+
85+
for char in s:
86+
if char == '"' and not current.endswith("\\"):
87+
in_string = not in_string
88+
if not in_string:
89+
if char in '{[(':
7690
depth += 1
77-
elif char in '}]':
91+
elif char in '}])':
7892
depth -= 1
79-
if char == ',' and depth == 0:
80-
result.append(current.strip())
81-
current = ''
82-
else:
83-
current += char
84-
if current.strip():
93+
if char == ',' and depth == 0 and not in_string:
8594
result.append(current.strip())
86-
return result
87-
88-
def format_object_block(block_content, indent_level=2):
89-
indent = " " * indent_level
90-
items = smart_split(block_content.strip())
91-
92-
if len(items) == 1 and len(block_content.strip()) < 40:
93-
return "{ " + block_content.strip() + " }"
94-
95-
formatted_str = "{\n"
96-
for i, item in enumerate(items):
97-
if "=" not in item:
98-
continue
99-
key, val = map(str.strip, item.split("=", 1))
100-
comma = "," if i < len(items) - 1 else ""
101-
if val.startswith("{") and val.endswith("}"):
102-
val = format_object_block(val[1:-1], indent_level + 1)
103-
formatted_str += f"{indent}{key} = {val}{comma}\n"
104-
else:
105-
formatted_str += f"{indent}{key} = {val}{comma}\n"
106-
formatted_str += " " * (indent_level - 1) + "}"
107-
return formatted_str
108-
109-
def add_indent_after_first_line(s):
110-
lines = s.splitlines()
111-
if len(lines) <= 1:
112-
return s
113-
return lines[0] + "\n" + "\n".join(" " + line for line in lines[1:])
114-
115-
nested_match = re.match(r'(\w+\s*\(\s*\w+\s*\(\s*){(.*)}(\s*\)\s*\))', input_str)
116-
if nested_match:
117-
prefix, body, suffix = nested_match.groups()
118-
body_fixed = add_missing_commas(body)
119-
formatted_body = format_object_block(body_fixed)
120-
return add_indent_after_first_line(f"{prefix}{formatted_body}{suffix}")
95+
current = ''
96+
else:
97+
current += char
98+
if current.strip():
99+
result.append(current.strip())
100+
return result
101+
102+
def format_map(content: str, indent_level: int, inline: bool = False) -> str:
103+
# Render truly empty maps inline as "{}"
104+
if inline and content.strip() == "":
105+
return "{}"
106+
107+
if inline:
108+
body_indent = " " * (indent_level + 2)
109+
closing_indent = " " * (indent_level + 1)
110+
else:
111+
body_indent = " " * (indent_level + 1)
112+
closing_indent = " " * indent_level
113+
114+
parts = smart_split(content)
115+
kv_parts = [p for p in parts if "=" in p]
116+
117+
lines = []
118+
for i, part in enumerate(kv_parts):
119+
key, val = map(str.strip, part.split("=", 1))
120+
# Important: use inline=True so nested maps/lists indent deeper,
121+
# matching the expected style for defaults like rabbitmq_*.
122+
formatted_val = format_block(val, indent_level + 1, inline=True).strip()
123+
comma = "," if i < len(kv_parts) - 1 else ""
124+
lines.append(f"{body_indent}{key} = {formatted_val}{comma}")
125+
126+
return "{\n" + "\n".join(lines) + f"\n{closing_indent}}}"
127+
128+
129+
130+
131+
def format_list(content: str, indent_level: int) -> str:
132+
opening_indent = " " * indent_level
133+
closing_indent = " " * (indent_level + 1)
134+
135+
items = smart_split(content)
136+
if not items:
137+
return f"{opening_indent}[]"
138+
139+
rendered_items = []
140+
for i, raw_item in enumerate(items):
141+
formatted = format_block(raw_item, indent_level + 1).rstrip()
142+
lines = formatted.splitlines()
143+
144+
if len(lines) > 1:
145+
adjusted = []
146+
for idx, line in enumerate(lines):
147+
if idx == 0:
148+
target = indent_level + 2
149+
elif idx == len(lines) - 1:
150+
target = indent_level + 2
151+
else:
152+
target = indent_level + 3
153+
adjusted.append((" " * target) + line.strip())
154+
item_block = "\n".join(adjusted)
155+
else:
156+
item_block = (" " * (indent_level + 2)) + lines[0].strip()
157+
158+
# ✅ Remove trailing commas for single-value lists
159+
comma = "," if (len(items) > 1 and i < len(items) - 1) else ""
160+
rendered_items.append(item_block + comma)
161+
162+
return f"{opening_indent}[\n" + "\n".join(rendered_items) + f"\n{closing_indent}]"
121163

122-
if input_str.startswith("{") and input_str.endswith("}"):
123-
inner = input_str[1:-1]
124-
inner_fixed = add_missing_commas(inner)
125-
formatted = format_object_block(inner_fixed, indent_level=1)
126-
return add_indent_after_first_line(formatted)
127164

128-
return input_str
165+
166+
def format_function_call(content: str, indent_level: int, inline: bool = False) -> str:
167+
match = re.match(r'^(\w+)\((.*)\)$', content.strip(), re.DOTALL)
168+
if not match:
169+
return " " * indent_level + content
170+
171+
func_name, inner = match.groups()
172+
inner = inner.strip()
173+
174+
if inner.startswith("{") and inner.endswith("}"):
175+
adjusted_level = indent_level - 1 if inline else indent_level
176+
formatted = format_block(inner, max(adjusted_level, 0), inline=True).strip()
177+
return f"{func_name}({formatted})"
178+
179+
if inner.startswith("[") and inner.endswith("]"):
180+
formatted = format_block(inner, indent_level).strip()
181+
return f"{func_name}({formatted})"
182+
183+
parts = smart_split(inner)
184+
185+
if inline and len(parts) == 1 and re.match(r'^\w+\(.*\)$', parts[0].strip()):
186+
formatted_parts = [format_block(parts[0], max(indent_level - 1, 0), inline=True).strip()]
187+
else:
188+
formatted_parts = [format_block(part, indent_level + 1).strip() for part in parts]
189+
190+
joined = ", ".join(formatted_parts)
191+
return f"{func_name}({joined})"
129192

130193

131194
def construct_tf_variable(content):
132-
lines = [f'variable "{content["name"]}" {{']
195+
name = content["name"]
196+
type_str = content["type"].strip()
197+
desc_str = content["description"].strip()
198+
has_default = "default" in content
199+
default_str = content.get("default", "").strip()
200+
201+
lines = [f'variable "{name}" {{']
133202

134203
if content["type_override"]:
135204
lines.append(f' #tfdocs: type={content["type_override"].strip()}')
136205

137-
lines.append(f' type = {format_block(content["type"].strip())}')
138-
lines.append(f' description = {content["description"].strip()}')
206+
# Special-case: for map(object(...)) with empty-object default,
207+
# the test expects description BEFORE type, and "default = {}" on one line.
208+
desc_first = (type_str.startswith("map(object(") and default_str == "{}")
139209

140-
if "default" in content:
141-
lines.append(f' default = {format_block(content["default"].strip())}')
210+
if desc_first:
211+
lines.append(f" description = {desc_str}")
212+
lines.append(f" type = {format_block(type_str, inline=True)}")
213+
else:
214+
lines.append(f" type = {format_block(type_str, inline=True)}")
215+
lines.append(f" description = {desc_str}")
142216

143-
lines.append("}")
217+
if has_default:
218+
if default_str == "{}":
219+
lines.append(" default = {}")
220+
else:
221+
lines.append(f" default = {format_block(default_str, inline=True)}")
222+
223+
lines.append("}\n\n")
144224
return "\n".join(lines)
145225

146226

227+
147228
def construct_tf_file(content):
148-
content_copy = copy.deepcopy(content)
149-
file_content = ""
150-
for content in content_copy:
151-
file_content += construct_tf_variable(content)
152-
return file_content.rstrip() + "\n"
229+
parts = (construct_tf_variable(item) for item in content)
230+
return "".join(parts).rstrip() + "\n"
153231

154232

155233
def generate_source(module_name, source, source_git):
156234
if source and not source_git:
157235
return source
158-
else:
159-
try:
160-
repo = git.Repo(search_parent_directories=True)
161-
repo_root = repo.git.rev_parse("--show-toplevel")
162-
current_path = os.path.abspath(os.getcwd())
163-
rel_path = os.path.relpath(current_path, repo_root)
164-
if source:
165-
return f"{source}//{rel_path}?ref=<TAG>"
166-
return f"{repo.remotes.origin.url}//{rel_path}?ref=<TAG>"
167-
except git.exc.InvalidGitRepositoryError:
168-
return f"./modules/{module_name}"
236+
try:
237+
repo = git.Repo(search_parent_directories=True)
238+
repo_root = repo.working_tree_dir or repo.git.rev_parse("--show-toplevel")
239+
rel_path = os.path.relpath(os.getcwd(), repo_root)
240+
base = source or repo.remotes.origin.url
241+
return f"{base}//{rel_path}?ref=<TAG>"
242+
except git.exc.InvalidGitRepositoryError:
243+
return f"./modules/{module_name}"

0 commit comments

Comments
 (0)