Skip to content

Commit 3890939

Browse files
committed
Stubfix fix
1 parent 24c2711 commit 3890939

File tree

2 files changed

+37
-87
lines changed

2 files changed

+37
-87
lines changed

duckdb/_duckdb/typing.pyi

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import typing
1+
import typing as pytyping
22

33
__all__: list[str] = [
44
"BIGINT",
@@ -32,19 +32,19 @@ __all__: list[str] = [
3232
]
3333

3434
class DuckDBPyType:
35-
@typing.overload
35+
@pytyping.overload
3636
def __eq__(self, other: DuckDBPyType) -> bool: ...
3737
def __eq__(self, other: str) -> bool: ...
3838
def __getattr__(self, name: str) -> DuckDBPyType: ...
3939
def __getitem__(self, name: str) -> DuckDBPyType: ...
4040
def __hash__(self) -> int: ...
41-
@typing.overload
41+
@pytyping.overload
4242
def __init__(self, type_str: str, connection: ...) -> None: ...
43-
@typing.overload
43+
@pytyping.overload
4444
def __init__(self, arg0: ...) -> None: ...
45-
@typing.overload
45+
@pytyping.overload
4646
def __init__(self, arg0: ...) -> None: ...
47-
def __init__(self, obj: typing.Any) -> None: ...
47+
def __init__(self, obj: pytyping.Any) -> None: ...
4848
@property
4949
def children(self) -> list: ...
5050
@property

scripts/stubfix.py

Lines changed: 31 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -65,46 +65,6 @@ def _extract_def_name(def_line: str) -> str | None:
6565
return m.group(1) if m else None
6666

6767

68-
def _remove_docstring_from_block(block: list[str]) -> tuple[list[str], list[str] | None]:
69-
"""Remove the first docstring in the function body (if present).
70-
Return (cleaned_block, docstring_lines_or_None).
71-
Docstring lines are returned verbatim (including the triple quotes) so they can be reinserted.
72-
Be conservative: if the docstring terminator is missing, we do not remove anything.
73-
"""
74-
def_idx = None
75-
for idx, ln in enumerate(block):
76-
if ln.lstrip().startswith("def "):
77-
def_idx = idx
78-
break
79-
if def_idx is None:
80-
return block, None
81-
82-
i = def_idx + 1
83-
while i < len(block):
84-
ln = block[i]
85-
if '"""' in ln or "'''" in ln:
86-
delim = '"""' if '"""' in ln else "'''"
87-
if ln.count(delim) >= 2:
88-
# single-line docstring
89-
doc_lines = [ln]
90-
new_block = block[:i] + block[i + 1 :]
91-
return new_block, doc_lines
92-
# multi-line: find closing delimiter
93-
doc_lines = [ln]
94-
j = i + 1
95-
while j < len(block) and delim not in block[j]:
96-
doc_lines.append(block[j])
97-
j += 1
98-
if j < len(block):
99-
doc_lines.append(block[j]) # include closing line
100-
new_block = block[:i] + block[j + 1 :]
101-
return new_block, doc_lines
102-
# unterminated docstring -> conservative: don't remove
103-
return block, None
104-
i += 1
105-
return block, None
106-
107-
10868
def _indent_to_string(s: str) -> str:
10969
return s[: len(s) - len(s.lstrip(" "))]
11070

@@ -144,6 +104,20 @@ def _fix_overloaded_functions(text: str) -> str:
144104
group_blocks.append(block)
145105
j = block_end
146106

107+
# Check if there's a non-overloaded function immediately following
108+
# that has the same name (this would be the actual implementation)
109+
has_implementation = False
110+
if j < n and not lines[j].lstrip().startswith("@typing.overload"):
111+
# Look for a def line in the next block
112+
impl_block_end = _collect_block_boundaries(lines, j)
113+
impl_block = lines[j:impl_block_end]
114+
for ln in impl_block:
115+
if ln.lstrip().startswith("def "):
116+
impl_name = _extract_def_name(ln)
117+
if impl_name == group_name:
118+
has_implementation = True
119+
break
120+
147121
# if single block only, emit as-is
148122
if len(group_blocks) <= 1:
149123
if j > i:
@@ -154,26 +128,15 @@ def _fix_overloaded_functions(text: str) -> str:
154128
i += 1
155129
continue
156130

157-
# multiple overloads for same function name -> clean docstrings
158-
best_doc: list[str] | None = None
159-
cleaned_blocks: list[list[str]] = []
160-
for blk in group_blocks:
161-
cleaned, doc = _remove_docstring_from_block(blk)
162-
cleaned_blocks.append(cleaned)
163-
if best_doc is None and doc:
164-
# choose the first non-empty docstring
165-
joined = "\n".join(line.strip() for line in doc)
166-
if joined.strip():
167-
best_doc = doc
168-
169-
last_idx = len(cleaned_blocks) - 1
170-
171-
for idx_bl, blk in enumerate(cleaned_blocks):
131+
# Process multiple overload blocks
132+
last_idx = len(group_blocks) - 1
133+
134+
for idx_bl, blk in enumerate(group_blocks):
172135
is_last = idx_bl == last_idx
173136
blk_copy = list(blk)
174137

175-
# drop final @typing.overload on last overload
176-
if is_last:
138+
# Only drop @typing.overload on the last overload if there's no separate implementation
139+
if is_last and not has_implementation:
177140
# remove first decorator if it's exactly '@typing.overload'
178141
for d_idx, ln in enumerate(blk_copy):
179142
if ln.lstrip().startswith("@"):
@@ -203,29 +166,14 @@ def _fix_overloaded_functions(text: str) -> str:
203166
def_indent = _indent_to_string(def_line)
204167
body_indent = def_indent + " "
205168

206-
if not is_last:
207-
# non-last overloads: ensure single `...` body
208-
new_blk = [*blk_copy[: def_idx + 1], body_indent + "..."]
209-
out.extend(new_blk)
169+
if not is_last or has_implementation:
170+
# For non-last blocks or when there's a separate implementation,
171+
# just emit the signature
172+
out.extend([*blk_copy[: def_idx + 1]])
210173
else:
211-
# last overload: if we have best_doc, reinsert it (reindented)
212-
if best_doc:
213-
# compute minimal leading indent in best_doc (so we can re-indent)
214-
minimal = None
215-
for ln in best_doc:
216-
if ln.strip() == "":
217-
continue
218-
lead = len(ln) - len(ln.lstrip(" "))
219-
if minimal is None or lead < minimal:
220-
minimal = lead
221-
if minimal is None:
222-
minimal = 0
223-
stripped_doc = [body_indent + (ln[minimal:]) for ln in best_doc]
224-
rem = blk_copy[def_idx + 1 :]
225-
new_blk = blk_copy[: def_idx + 1] + stripped_doc + rem
226-
out.extend(new_blk)
227-
else:
228-
out.extend(blk_copy)
174+
# This is the last block and there's no separate implementation
175+
out.extend(blk_copy)
176+
229177
# advance index past the group
230178
i = j
231179
else:
@@ -332,8 +280,10 @@ def fix_stub(path: Path):
332280

333281
path.write_text(text)
334282

335-
print(f"[stub fixer] Replaced typing symbols: {sorted(replaced_typing)}")
336-
print(f"[stub fixer] Kept stdlib typing symbols: {sorted(kept_typing)}")
283+
if replaced_typing:
284+
print(f"[stub fixer] Replaced stdlib typing symbols: {sorted(replaced_typing)}")
285+
if kept_typing:
286+
print(f"[stub fixer] Kept duckdb typing symbols: {sorted(kept_typing)}")
337287

338288

339289
def _is_valid_stubfile(path: Path) -> bool:

0 commit comments

Comments
 (0)