diff --git a/src/sqlalchemy_declarative_extensions/dialects/postgresql/schema.py b/src/sqlalchemy_declarative_extensions/dialects/postgresql/schema.py index 71f9fdb..6bf9e18 100644 --- a/src/sqlalchemy_declarative_extensions/dialects/postgresql/schema.py +++ b/src/sqlalchemy_declarative_extensions/dialects/postgresql/schema.py @@ -117,6 +117,7 @@ def char_literals(*literals: str) -> Collection[BindParameter]: column("prosecdef"), column("prokind"), column("provolatile"), + column("pronargs"), column("proargnames"), column("proargmodes"), column("proargtypes"), @@ -318,6 +319,23 @@ def get_types(arg_type_oids): ) +def get_defaults(): + arg_num = ( + func.generate_series(1, pg_proc.c.pronargs) + .table_valued("arg_num", with_ordinality="ordinality") + .alias("arg_num") + ) + return ( + select( + func.array_agg( + func.pg_get_function_arg_default(pg_proc.c.oid, arg_num.c.arg_num) + ) + ) + .select_from(arg_num) + .scalar_subquery() + ) + + def get_procedures_query(version_info): source = get_source_column(version_info) return ( @@ -334,9 +352,7 @@ def get_procedures_query(version_info): func.coalesce( get_types(pg_proc.c.proallargtypes), get_types(pg_proc.c.proargtypes) ).label("arg_types"), - func.pg_get_expr( - pg_proc.c.proargdefaults, func.cast(literal("pg_proc"), REGCLASS) - ).label("arg_defaults"), + get_defaults().label("arg_defaults"), ) .select_from( pg_proc.join(pg_namespace, pg_proc.c.pronamespace == pg_namespace.c.oid) @@ -369,9 +385,7 @@ def get_functions_query(version_info): func.coalesce( get_types(pg_proc.c.proallargtypes), get_types(pg_proc.c.proargtypes) ).label("arg_types"), - func.pg_get_expr( - pg_proc.c.proargdefaults, func.cast(literal("pg_proc"), REGCLASS) - ).label("arg_defaults"), + get_defaults().label("arg_defaults"), ) .select_from( pg_proc.join(pg_namespace, pg_proc.c.pronamespace == pg_namespace.c.oid) diff --git a/tests/dialect/postgresql/test_function_defaults.py b/tests/dialect/postgresql/test_function_defaults.py new file mode 100644 index 0000000..75bef84 --- /dev/null +++ b/tests/dialect/postgresql/test_function_defaults.py @@ -0,0 +1,45 @@ +import pytest +from pytest_mock_resources import PostgresConfig, create_postgres_fixture +from sqlalchemy import text + +from sqlalchemy_declarative_extensions import Functions +from sqlalchemy_declarative_extensions.dialects.postgresql import ( + Function, + FunctionParam, + FunctionVolatility, +) +from sqlalchemy_declarative_extensions.function.compare import compare_functions + +pg = create_postgres_fixture(scope="function", engine_kwargs={"echo": True}) + + + + +@pytest.mark.parametrize( + ("default_a", "default_b", "default_c"), + [(None, None, "''::text"), (1, 0, "'m'::text"), (None, 2, "'ft'::text")], +) +def test_function_argument_defaults(pg, default_a, default_b, default_c): + add_label_function = Function( + name="add_label", + definition=""" + BEGIN + RETURN (((a + b))::text || c); + END; + """, + parameters=[ + FunctionParam("a", "integer", default=default_a), + FunctionParam("b", "integer", default=default_b), + FunctionParam("c", "text", default=default_c), + ], + returns="TEXT", + volatility=FunctionVolatility.STABLE, + language="plpgsql", + ).normalize() + create_function = add_label_function.to_sql_create() + functions = Functions([add_label_function]) + with pg.connect() as connection: + connection.execute(text("\n".join(create_function))) + diff = compare_functions(connection, functions) + for op in diff: + assert op.from_function == op.function