Skip to content

Commit eea161c

Browse files
authored
Merge pull request #132 from Dan-Knott/dk/fix-function-arg-defaults
Fix PostgreSQL function argument default parsing
2 parents d8ab535 + a389397 commit eea161c

File tree

2 files changed

+65
-6
lines changed

2 files changed

+65
-6
lines changed

src/sqlalchemy_declarative_extensions/dialects/postgresql/schema.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ def char_literals(*literals: str) -> Collection[BindParameter]:
117117
column("prosecdef"),
118118
column("prokind"),
119119
column("provolatile"),
120+
column("pronargs"),
120121
column("proargnames"),
121122
column("proargmodes"),
122123
column("proargtypes"),
@@ -318,6 +319,23 @@ def get_types(arg_type_oids):
318319
)
319320

320321

322+
def get_defaults():
323+
arg_num = (
324+
func.generate_series(1, pg_proc.c.pronargs)
325+
.table_valued("arg_num", with_ordinality="ordinality")
326+
.alias("arg_num")
327+
)
328+
return (
329+
select(
330+
func.array_agg(
331+
func.pg_get_function_arg_default(pg_proc.c.oid, arg_num.c.arg_num)
332+
)
333+
)
334+
.select_from(arg_num)
335+
.scalar_subquery()
336+
)
337+
338+
321339
def get_procedures_query(version_info):
322340
source = get_source_column(version_info)
323341
return (
@@ -334,9 +352,7 @@ def get_procedures_query(version_info):
334352
func.coalesce(
335353
get_types(pg_proc.c.proallargtypes), get_types(pg_proc.c.proargtypes)
336354
).label("arg_types"),
337-
func.pg_get_expr(
338-
pg_proc.c.proargdefaults, func.cast(literal("pg_proc"), REGCLASS)
339-
).label("arg_defaults"),
355+
get_defaults().label("arg_defaults"),
340356
)
341357
.select_from(
342358
pg_proc.join(pg_namespace, pg_proc.c.pronamespace == pg_namespace.c.oid)
@@ -369,9 +385,7 @@ def get_functions_query(version_info):
369385
func.coalesce(
370386
get_types(pg_proc.c.proallargtypes), get_types(pg_proc.c.proargtypes)
371387
).label("arg_types"),
372-
func.pg_get_expr(
373-
pg_proc.c.proargdefaults, func.cast(literal("pg_proc"), REGCLASS)
374-
).label("arg_defaults"),
388+
get_defaults().label("arg_defaults"),
375389
)
376390
.select_from(
377391
pg_proc.join(pg_namespace, pg_proc.c.pronamespace == pg_namespace.c.oid)
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
import pytest
2+
from pytest_mock_resources import PostgresConfig, create_postgres_fixture
3+
from sqlalchemy import text
4+
5+
from sqlalchemy_declarative_extensions import Functions
6+
from sqlalchemy_declarative_extensions.dialects.postgresql import (
7+
Function,
8+
FunctionParam,
9+
FunctionVolatility,
10+
)
11+
from sqlalchemy_declarative_extensions.function.compare import compare_functions
12+
13+
pg = create_postgres_fixture(scope="function", engine_kwargs={"echo": True})
14+
15+
16+
17+
18+
@pytest.mark.parametrize(
19+
("default_a", "default_b", "default_c"),
20+
[(None, None, "''::text"), (1, 0, "'m'::text"), (None, 2, "'ft'::text")],
21+
)
22+
def test_function_argument_defaults(pg, default_a, default_b, default_c):
23+
add_label_function = Function(
24+
name="add_label",
25+
definition="""
26+
BEGIN
27+
RETURN (((a + b))::text || c);
28+
END;
29+
""",
30+
parameters=[
31+
FunctionParam("a", "integer", default=default_a),
32+
FunctionParam("b", "integer", default=default_b),
33+
FunctionParam("c", "text", default=default_c),
34+
],
35+
returns="TEXT",
36+
volatility=FunctionVolatility.STABLE,
37+
language="plpgsql",
38+
).normalize()
39+
create_function = add_label_function.to_sql_create()
40+
functions = Functions([add_label_function])
41+
with pg.connect() as connection:
42+
connection.execute(text("\n".join(create_function)))
43+
diff = compare_functions(connection, functions)
44+
for op in diff:
45+
assert op.from_function == op.function

0 commit comments

Comments
 (0)