Skip to content

Commit d7daa11

Browse files
committed
stubfix script placeholder
1 parent 03fd876 commit d7daa11

File tree

1 file changed

+360
-0
lines changed

1 file changed

+360
-0
lines changed

scripts/stubfix.py

Lines changed: 360 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,360 @@
1+
import re
2+
import typing as py_typing
3+
from pathlib import Path
4+
from typing import Optional
5+
6+
7+
def _fix_exits(text: str) -> str:
8+
return re.sub(
9+
r"def __exit__\([^)]+\)[^:]+:",
10+
"def __exit__(self, exc_type: object, exc: object, traceback: object) -> None:",
11+
text,
12+
flags=re.DOTALL,
13+
)
14+
15+
16+
def _fix_imports(text: str) -> str:
17+
"""Ensure 'from . import typing as duckdb_typing' exists (replace plain 'from . import typing'),
18+
and ensure 'import typing' (stdlib) is present somewhere.
19+
"""
20+
duckdb_typing_import = "from . import typing as duckdb_typing"
21+
# Replace exact `from . import typing` occurrences (line anchored)
22+
text = re.sub(r"(?m)^\s*from \. import typing\s*$", duckdb_typing_import, text)
23+
24+
# Ensure stdlib typing present somewhere
25+
if not re.search(r"(?m)^\s*(import typing|from typing import\b)", text):
26+
lines = text.splitlines()
27+
for i, ln in enumerate(lines):
28+
if ln.strip() == duckdb_typing_import:
29+
lines.insert(i, "import typing")
30+
text = "\n".join(lines)
31+
break
32+
else:
33+
# If we didn't find the duckdb_typing import, put stdlib import at top
34+
text = "import typing\n" + text
35+
return text
36+
37+
38+
# -------------------------
39+
# Overload consolidation
40+
# -------------------------
41+
def _collect_block_boundaries(lines: list[str], start: int) -> int:
42+
"""Given a starting index pointing at a decorator line (starting with '@'),
43+
return the index just after the entire function block (decorators + def + body).
44+
If def not found after decorators, return start+1 (conservative).
45+
"""
46+
n = len(lines)
47+
j = start
48+
# skip contiguous decorator lines
49+
while j < n and lines[j].lstrip().startswith("@"):
50+
j += 1
51+
if j >= n:
52+
return start + 1
53+
# must find a def
54+
if not lines[j].lstrip().startswith("def "):
55+
return start + 1
56+
def_line = lines[j]
57+
def_indent = len(def_line) - len(def_line.lstrip(" "))
58+
j += 1
59+
# collect body: lines with indentation > def_indent (allow blank lines that are followed by indented lines)
60+
while j < n:
61+
ln = lines[j]
62+
if ln.strip() == "":
63+
# include blank line only if next non-empty line is more-indented than def line
64+
k = j + 1
65+
while k < n and lines[k].strip() == "":
66+
k += 1
67+
if k < n and (len(lines[k]) - len(lines[k].lstrip(" "))) > def_indent:
68+
j += 1
69+
continue
70+
break
71+
indent = len(ln) - len(ln.lstrip(" "))
72+
if indent > def_indent:
73+
j += 1
74+
continue
75+
break
76+
return j
77+
78+
79+
def _extract_def_name(def_line: str) -> Optional[str]:
80+
m = re.search(r"\bdef\s+([A-Za-z_][A-Za-z0-9_]*)\s*\(", def_line)
81+
return m.group(1) if m else None
82+
83+
84+
def _remove_docstring_from_block(block: list[str]) -> tuple[list[str], Optional[list[str]]]:
85+
"""Remove the first docstring in the function body (if present).
86+
Return (cleaned_block, docstring_lines_or_None).
87+
Docstring lines are returned verbatim (including the triple quotes) so they can be reinserted.
88+
Be conservative: if the docstring terminator is missing, we do not remove anything.
89+
"""
90+
def_idx = None
91+
for idx, ln in enumerate(block):
92+
if ln.lstrip().startswith("def "):
93+
def_idx = idx
94+
break
95+
if def_idx is None:
96+
return block, None
97+
98+
i = def_idx + 1
99+
while i < len(block):
100+
ln = block[i]
101+
if '"""' in ln or "'''" in ln:
102+
delim = '"""' if '"""' in ln else "'''"
103+
if ln.count(delim) >= 2:
104+
# single-line docstring
105+
doc_lines = [ln]
106+
new_block = block[:i] + block[i + 1 :]
107+
return new_block, doc_lines
108+
# multi-line: find closing delimiter
109+
doc_lines = [ln]
110+
j = i + 1
111+
while j < len(block) and delim not in block[j]:
112+
doc_lines.append(block[j])
113+
j += 1
114+
if j < len(block):
115+
doc_lines.append(block[j]) # include closing line
116+
new_block = block[:i] + block[j + 1 :]
117+
return new_block, doc_lines
118+
# unterminated docstring -> conservative: don't remove
119+
return block, None
120+
i += 1
121+
return block, None
122+
123+
124+
def _indent_to_string(s: str) -> str:
125+
return s[: len(s) - len(s.lstrip(" "))]
126+
127+
128+
def _fix_overloaded_functions(text: str) -> str:
    """Consolidate runs of adjacent ``@typing.overload`` blocks for one function.

    A run is a sequence of blocks that each start with ``@typing.overload``,
    follow each other with no line in between, and define the same name.
    For runs of two or more blocks:

    - every overload except the last gets a bare ``...`` body (its docstring
      and any other body lines are dropped);
    - the last overload loses its ``@typing.overload`` decorator and receives
      the first docstring found in the run, re-indented to its body.

    Single overload blocks and all other lines are emitted unchanged.

    Args:
        text: Stub source to rewrite.

    Returns:
        The rewritten source, joined with ``\\n``. A trailing newline in the
        input is kept.
    """
    # `def f(...) -> X: ...` — the body already sits on the def line.
    inline_ellipsis = re.compile(r":\s*\.\.\.\s*$")

    def find_def(block: list[str]) -> Optional[int]:
        # Index of the first `def` line in a block, if any.
        return next(
            (pos for pos, ln in enumerate(block) if ln.lstrip().startswith("def ")),
            None,
        )

    lines = text.splitlines()
    out: list[str] = []
    n = len(lines)
    i = 0

    while i < n:
        if not lines[i].lstrip().startswith("@typing.overload"):
            out.append(lines[i])
            i += 1
            continue

        # Gather consecutive overload blocks that share a function name.
        group_blocks: list[list[str]] = []
        group_name: Optional[str] = None
        j = i
        while j < n and lines[j].lstrip().startswith("@typing.overload"):
            block_end = _collect_block_boundaries(lines, j)
            block = lines[j:block_end]
            def_pos = find_def(block)
            name = _extract_def_name(block[def_pos]) if def_pos is not None else None
            if group_name is None:
                group_name = name
            # A different name ends the run; that block is not consumed.
            if name != group_name:
                break
            group_blocks.append(block)
            j = block_end

        # Nothing to consolidate: emit the lone block verbatim.
        if len(group_blocks) <= 1:
            end = j if j > i else i + 1
            out.extend(lines[i:end])
            i = end
            continue

        # Strip docstrings from every block, remembering the first one found.
        best_doc: Optional[list[str]] = None
        cleaned_blocks: list[list[str]] = []
        for blk in group_blocks:
            cleaned, doc = _remove_docstring_from_block(blk)
            cleaned_blocks.append(cleaned)
            if best_doc is None and doc and "".join(doc).strip():
                best_doc = doc

        last_idx = len(cleaned_blocks) - 1
        for idx_bl, blk in enumerate(cleaned_blocks):
            is_last = idx_bl == last_idx
            blk_copy = list(blk)

            def_idx = find_def(blk_copy)
            if def_idx is None:
                # Malformed block (no def): emit untouched, decorator included.
                out.extend(blk_copy)
                continue

            if is_last:
                # Drop the leading decorator only if it is exactly
                # `@typing.overload`.
                for d_idx, ln in enumerate(blk_copy):
                    if ln.lstrip().startswith("@"):
                        if ln.strip() == "@typing.overload":
                            blk_copy.pop(d_idx)
                            if d_idx < def_idx:
                                def_idx -= 1
                        break

            head = blk_copy[: def_idx + 1]
            rest = blk_copy[def_idx + 1 :]

            # Indentation for inserted body lines: copy the first real body
            # line, else fall back to the def's indent plus four spaces.
            body_indent = next(
                (_indent_to_string(ln) for ln in rest if ln.strip() != ""),
                None,
            )
            if body_indent is None:
                body_indent = _indent_to_string(head[-1]) + "    "

            has_inline_body = inline_ellipsis.search(head[-1]) is not None

            if not is_last:
                # An inline `...` already is the body; appending an indented
                # `...` after it would be a syntax error.
                if not has_inline_body:
                    head.append(body_indent + "...")
                out.extend(head)
                continue

            if not best_doc:
                out.extend(blk_copy)
                continue

            if has_inline_body:
                # Make room for the docstring: `def f(): ...` -> `def f():`.
                head[-1] = inline_ellipsis.sub(":", head[-1])

            # Re-indent the docstring relative to its least-indented line.
            minimal = min(
                (len(ln) - len(ln.lstrip(" ")) for ln in best_doc if ln.strip()),
                default=0,
            )
            doc_lines = [
                body_indent + ln[minimal:] if ln.strip() else ""
                for ln in best_doc
            ]
            out.extend(head + doc_lines + rest)

        # Continue after the whole run.
        i = j

    result = "\n".join(out)
    # splitlines() discards the final line terminator; put it back.
    if out and text.endswith(("\n", "\r")):
        result += "\n"
    return result
252+
253+
254+
# -------------------------
255+
# Typing shadow replacement
256+
# -------------------------
257+
def _fix_typing_shadowing(text: str) -> tuple[str, set, set]:
258+
"""Replace occurrences of `typing.Symbol` with `duckdb_typing.Symbol` when the symbol
259+
is not present in the stdlib `typing` module. Use a negative lookbehind so we don't
260+
touch occurrences already prefixed with `duckdb_typing.`.
261+
Returns (new_text, replaced_set, kept_set).
262+
"""
263+
typing_pattern = re.compile(r"(?<!duckdb_typing\.)\btyping\.([A-Za-z_][A-Za-z0-9_]*)\b")
264+
replaced_typing = set()
265+
kept_typing = set()
266+
267+
def typing_repl(m) -> str:
268+
symbol = m.group(1)
269+
if hasattr(py_typing, symbol):
270+
kept_typing.add(symbol)
271+
return m.group(0)
272+
replaced_typing.add(symbol)
273+
return f"duckdb_typing.{symbol}"
274+
275+
new_text = typing_pattern.sub(typing_repl, text)
276+
return new_text, replaced_typing, kept_typing
277+
278+
279+
# -------------------------
280+
# Optional wrapping (fixed)
281+
# -------------------------
282+
def _fix_optionals(text: str) -> str:
283+
"""Wrap parameter annotations that have a default `= None` in `typing.Optional[...]`
284+
unless the annotation already indicates optional (contains 'None' or 'Optional' or 'Union[..., None]')
285+
or is in the allowlist (Any, object, ClassVar, etc.).
286+
287+
This intentionally identifies the annotation *after* the colon and stops at the character
288+
right before the `=`; it does not try to parse entire function signatures but is much
289+
less greedy than prior attempts.
290+
"""
291+
# Capture colon + annotation (non-greedy) up to the equals; ensure we stop before comma, ), or newline
292+
pattern = re.compile(r"(?P<colon>:\s*)(?P<ann>[^=,\)\n]+?)(?P<trail>\s*)(?=\=\s*None)")
293+
294+
def repl(m: re.Match) -> str:
295+
colon = m.group("colon")
296+
ann_raw = m.group("ann")
297+
trail = m.group("trail") or ""
298+
ann = ann_raw.strip()
299+
300+
if not ann:
301+
return colon + ann_raw + trail
302+
303+
ann_lower = ann.lower()
304+
305+
# Skip if annotation already explicitly optional / contains None (covers `X | None`, `Union[..., None]`, etc.)
306+
if "none" in ann_lower:
307+
return colon + ann_raw + trail
308+
if ("optional[" in ann_lower) or ("typing.optional[" in ann_lower) or ("duckdb_typing.optional[" in ann_lower):
309+
return colon + ann_raw + trail
310+
if ("classvar[" in ann_lower) or ("typing.classvar[" in ann_lower):
311+
return colon + ann_raw + trail
312+
if ann in {"Any", "typing.Any", "duckdb_typing.Any", "object", "ClassVar"}:
313+
return colon + ann_raw + trail
314+
315+
# Otherwise wrap conservatively preserving original whitespace and annotation text
316+
return f"{colon}typing.Optional[{ann}]{trail}"
317+
318+
return pattern.sub(repl, text)
319+
320+
321+
# -------------------------
322+
# Main fixer
323+
# -------------------------
324+
def fix_stub(path: Path) -> None:
    """Rewrite the ``.pyi`` stub at *path* in place.

    Steps, in order:
        1. Normalize/ensure imports.
        2. Normalize ``__exit__`` signatures.
        3. Replace tabs with four spaces.
        4. Consolidate overloaded functions.
        5. Replace shadowed ``typing`` symbols with ``duckdb_typing``.
        6. Wrap ``= None`` defaults in ``typing.Optional[...]`` where safe.

    Progress and a summary of replaced/kept typing symbols are printed to
    stdout.

    Args:
        path: Stub file to read and overwrite.
    """
    print(f"=== Fixing {path}")
    # Explicit UTF-8 so the result does not depend on the platform's default
    # encoding.
    text = path.read_text(encoding="utf-8")

    text = _fix_imports(text)

    text = _fix_exits(text)

    # Normalize tabs before the overload pass, which measures indentation in
    # space characters only.
    text = text.replace("\t", "    ")

    text = _fix_overloaded_functions(text)

    text, replaced_typing, kept_typing = _fix_typing_shadowing(text)

    text = _fix_optionals(text)

    path.write_text(text, encoding="utf-8")

    print(f"[stub fixer] Replaced typing symbols: {sorted(replaced_typing)}")
    print(f"[stub fixer] Kept stdlib typing symbols: {sorted(kept_typing)}")
352+
353+
354+
# -------------------------
355+
# Script entry point: fixes the DuckDB stub files in place when run directly
356+
# -------------------------
357+
if __name__ == "__main__":
    # Each stub is rewritten in place, in the order listed.
    for stub_path in (
        "_duckdb-stubs/__init__.pyi",
        "_duckdb-stubs/typing.pyi",
        "_duckdb-stubs/functional.pyi",
    ):
        fix_stub(Path(stub_path))

0 commit comments

Comments
 (0)