Skip to content

Commit 83d3f94

Browse files
This hopefully fixes duckdb#26: - Changed the type of the file_globs parameter of from_parquet and read_parquet in connection_methods.json. - Added the generation of @overload functions in the generation wrappers python code.
2 parents a3ab403 + 7866f31 commit 83d3f94

File tree

3 files changed

+22
-20
lines changed

3 files changed

+22
-20
lines changed

scripts/connection_methods.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -412,7 +412,7 @@
412412
"fetch_record_batch",
413413
"arrow"
414414
],
415-
415+
416416
"function": "FetchRecordBatchReader",
417417
"docs": "Fetch an Arrow RecordBatchReader following execute()",
418418
"args": [
@@ -992,7 +992,7 @@
992992
"args": [
993993
{
994994
"name": "file_globs",
995-
"type": "str"
995+
"type": "List[str]"
996996
},
997997
{
998998
"name": "binary_as_string",

scripts/generate_connection_stubs.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,12 @@ def create_arguments(arguments) -> list:
5151
result.append(argument)
5252
return result
5353

54-
def create_definition(name, method) -> str:
55-
definition = f"def {name}("
54+
def create_definition(name, method, overloaded: bool) -> str:
55+
if overloaded:
56+
definition: str = "@overload\n"
57+
else:
58+
definition: str = ""
59+
definition += f"def {name}("
5660
arguments = ['self']
5761
if 'args' in method:
5862
arguments.extend(create_arguments(method['args']))
@@ -65,20 +69,17 @@ def create_definition(name, method) -> str:
6569
definition += f" -> {method['return']}: ..."
6670
return definition
6771

68-
# We have "duplicate" methods, which are overloaded
69-
# maybe we should add @overload to these instead, but this is easier
70-
written_methods = set()
72+
# We have "duplicate" methods, which are overloaded.
73+
# We keep note of them to add the @overload decorator.
74+
overloaded_methods: set[str] = {m for m in connection_methods if isinstance(m['name'], list)}
7175

7276
for method in connection_methods:
7377
if isinstance(method['name'], list):
7478
names = method['name']
7579
else:
7680
names = [method['name']]
7781
for name in names:
78-
if name in written_methods:
79-
continue
80-
body.append(create_definition(name, method))
81-
written_methods.add(name)
82+
body.append(create_definition(name, method, name in overloaded_methods))
8283

8384
# ---- End of generation code ----
8485

scripts/generate_connection_wrapper_stubs.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,12 @@ def create_arguments(arguments) -> list:
6666
result.append(argument)
6767
return result
6868

69-
def create_definition(name, method) -> str:
70-
definition = f"def {name}("
69+
def create_definition(name, method, overloaded: bool) -> str:
70+
if overloaded:
71+
definition: str = "@overload\n"
72+
else:
73+
definition: str = ""
74+
definition += f"def {name}("
7175
arguments = []
7276
if name in SPECIAL_METHOD_NAMES:
7377
arguments.append('df: pandas.DataFrame')
@@ -82,9 +86,9 @@ def create_definition(name, method) -> str:
8286
definition += f" -> {method['return']}: ..."
8387
return definition
8488

85-
# We have "duplicate" methods, which are overloaded
86-
# maybe we should add @overload to these instead, but this is easier
87-
written_methods = set()
89+
# We have "duplicate" methods, which are overloaded.
90+
# We keep note of them to add the @overload decorator.
91+
overloaded_methods: set[str] = {m for m in connection_methods if isinstance(m['name'], list)}
8892

8993
body = []
9094
for method in methods:
@@ -99,10 +103,7 @@ def create_definition(name, method) -> str:
99103
method['kwargs'].append({'name': 'connection', 'type': 'DuckDBPyConnection', 'default': '...'})
100104

101105
for name in names:
102-
if name in written_methods:
103-
continue
104-
body.append(create_definition(name, method))
105-
written_methods.add(name)
106+
body.append(create_definition(name, method, name in overloaded_methods))
106107

107108
# ---- End of generation code ----
108109

0 commit comments

Comments
 (0)