Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion c/driver/postgresql/connection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,15 @@ static const char* kSchemaQueryAll =
// Parameterized on schema_name, relkind
// Note that when binding relkind as a string it must look like {"r", "v", ...}
// (i.e., double quotes). Binding a binary list<string> element also works.
// Don't use pg_table_is_visible(): it is search_path-dependent and would hide tables
// in non-current schemas even when GetObjects is called with a schema filter.
static const char* kTablesQueryAll =
"SELECT c.relname, CASE c.relkind WHEN 'r' THEN 'table' WHEN 'v' THEN 'view' "
"WHEN 'm' THEN 'materialized view' WHEN 't' THEN 'TOAST table' "
"WHEN 'f' THEN 'foreign table' WHEN 'p' THEN 'partitioned table' END "
"AS reltype FROM pg_catalog.pg_class c "
"LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace "
"WHERE pg_catalog.pg_table_is_visible(c.oid) AND n.nspname = $1 AND c.relkind = "
"WHERE n.nspname = $1 AND c.relkind = "
"ANY($2)";

// Parameterized on schema_name, table_name
Expand Down
55 changes: 55 additions & 0 deletions c/driver/postgresql/postgresql_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,61 @@ TEST_F(PostgresConnectionTest, GetObjectsGetDbSchemas) {
ASSERT_NE(schema, nullptr) << "schema public not found";
}

TEST_F(PostgresConnectionTest, GetObjectsSchemaFilterFindsTablesOutsideSearchPath) {
ASSERT_THAT(AdbcConnectionNew(&connection, &error), IsOkStatus(&error));
ASSERT_THAT(AdbcConnectionInit(&connection, &database, &error), IsOkStatus(&error));

const std::string schema_name = "adbc_get_objects_test";
const std::string table_name = "schema_filter_table";

// Ensure the schema is not part of the current search_path.
ASSERT_THAT(
AdbcConnectionSetOption(&connection, ADBC_CONNECTION_OPTION_CURRENT_DB_SCHEMA,
"public", &error),
IsOkStatus(&error));

ASSERT_THAT(quirks()->EnsureDbSchema(&connection, schema_name, &error),
IsOkStatus(&error));
ASSERT_THAT(quirks()->DropTable(&connection, table_name, schema_name, &error),
IsOkStatus(&error));

{
adbc_validation::Handle<struct AdbcStatement> statement;
ASSERT_THAT(AdbcStatementNew(&connection, &statement.value, &error),
IsOkStatus(&error));

std::string create =
"CREATE TABLE \"" + schema_name + "\".\"" + table_name + "\" (ints INT)";
ASSERT_THAT(AdbcStatementSetSqlQuery(&statement.value, create.c_str(), &error),
IsOkStatus(&error));
ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, &error),
IsOkStatus(&error));
}

adbc_validation::StreamReader reader;
ASSERT_THAT(AdbcConnectionGetObjects(&connection, ADBC_OBJECT_DEPTH_TABLES, nullptr,
schema_name.c_str(), nullptr, nullptr, nullptr,
&reader.stream.value, &error),
IsOkStatus(&error));
ASSERT_NO_FATAL_FAILURE(reader.GetSchema());
ASSERT_NO_FATAL_FAILURE(reader.Next());
ASSERT_NE(nullptr, reader.array->release);
ASSERT_GT(reader.array->length, 0);

auto get_objects_data = adbc_validation::GetObjectsReader{&reader.array_view.value};
ASSERT_NE(*get_objects_data, nullptr)
<< "could not initialize the AdbcGetObjectsData object";

const auto catalog = adbc_validation::ConnectionGetOption(
&connection, ADBC_CONNECTION_OPTION_CURRENT_CATALOG, &error);
ASSERT_TRUE(catalog.has_value());

struct AdbcGetObjectsTable* table = InternalAdbcGetObjectsDataGetTableByName(
*get_objects_data, catalog->c_str(), schema_name.c_str(), table_name.c_str());
ASSERT_NE(table, nullptr) << "could not find " << schema_name << "." << table_name
<< " via GetObjects";
}

TEST_F(PostgresConnectionTest, GetObjectsGetAllFindsPrimaryKey) {
ASSERT_THAT(AdbcConnectionNew(&connection, &error), IsOkStatus(&error));
ASSERT_THAT(AdbcConnectionInit(&connection, &database, &error), IsOkStatus(&error));
Expand Down
41 changes: 40 additions & 1 deletion python/adbc_driver_postgresql/tests/test_dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import pyarrow
import pyarrow.dataset
import pytest

from adbc_driver_postgresql import ConnectionOptions, StatementOptions, dbapi


Expand Down Expand Up @@ -53,6 +52,46 @@ def test_conn_change_db_schema(postgres: dbapi.Connection) -> None:
assert postgres.adbc_current_db_schema == "dbapischema"


def test_get_objects_schema_filter_outside_search_path(
postgres: dbapi.Connection,
) -> None:
schema_name = "dbapi_get_objects_test"
table_name = "schema_filter_table"

# Regression test: adbc_get_objects(db_schema_filter=...) should not depend on the
# connection's current schema/search_path.
assert postgres.adbc_current_db_schema == "public"

with postgres.cursor() as cur:
cur.execute(f'CREATE SCHEMA IF NOT EXISTS "{schema_name}"')
cur.execute(f'DROP TABLE IF EXISTS "{schema_name}"."{table_name}"')
cur.execute(f'CREATE TABLE "{schema_name}"."{table_name}" (ints INT)')
postgres.commit()

assert postgres.adbc_current_db_schema == "public"

metadata = (
postgres.adbc_get_objects(
depth="tables",
db_schema_filter=schema_name,
table_name_filter=table_name,
)
.read_all()
.to_pylist()
)

catalog_name = postgres.adbc_current_catalog
catalog = next((row for row in metadata if row["catalog_name"] == catalog_name), None)
assert catalog is not None

schemas = catalog["catalog_db_schemas"]
assert len(schemas) == 1
assert schemas[0]["db_schema_name"] == schema_name
tables = schemas[0]["db_schema_tables"]
assert len(tables) == 1
assert tables[0]["table_name"] == table_name


def test_conn_get_info(postgres: dbapi.Connection) -> None:
info = postgres.adbc_get_info()
assert info["driver_name"] == "ADBC PostgreSQL Driver"
Expand Down
Loading