diff --git a/c/driver/postgresql/connection.cc b/c/driver/postgresql/connection.cc index cf4dea1577..7738a44fa1 100644 --- a/c/driver/postgresql/connection.cc +++ b/c/driver/postgresql/connection.cc @@ -72,13 +72,15 @@ static const char* kSchemaQueryAll = // Parameterized on schema_name, relkind // Note that when binding relkind as a string it must look like {"r", "v", ...} // (i.e., double quotes). Binding a binary list element also works. +// Don't use pg_table_is_visible(): it is search_path-dependent and would hide tables +// in non-current schemas even when GetObjects is called with a schema filter. static const char* kTablesQueryAll = "SELECT c.relname, CASE c.relkind WHEN 'r' THEN 'table' WHEN 'v' THEN 'view' " "WHEN 'm' THEN 'materialized view' WHEN 't' THEN 'TOAST table' " "WHEN 'f' THEN 'foreign table' WHEN 'p' THEN 'partitioned table' END " "AS reltype FROM pg_catalog.pg_class c " "LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace " - "WHERE pg_catalog.pg_table_is_visible(c.oid) AND n.nspname = $1 AND c.relkind = " + "WHERE n.nspname = $1 AND c.relkind = " "ANY($2)"; // Parameterized on schema_name, table_name diff --git a/c/driver/postgresql/postgresql_test.cc b/c/driver/postgresql/postgresql_test.cc index 5eca504da3..d457a14a00 100644 --- a/c/driver/postgresql/postgresql_test.cc +++ b/c/driver/postgresql/postgresql_test.cc @@ -372,6 +372,61 @@ TEST_F(PostgresConnectionTest, GetObjectsGetDbSchemas) { ASSERT_NE(schema, nullptr) << "schema public not found"; } +TEST_F(PostgresConnectionTest, GetObjectsSchemaFilterFindsTablesOutsideSearchPath) { + ASSERT_THAT(AdbcConnectionNew(&connection, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcConnectionInit(&connection, &database, &error), IsOkStatus(&error)); + + const std::string schema_name = "adbc_get_objects_test"; + const std::string table_name = "schema_filter_table"; + + // Ensure the schema is not part of the current search_path. + ASSERT_THAT( + AdbcConnectionSetOption(&connection, ADBC_CONNECTION_OPTION_CURRENT_DB_SCHEMA, + "public", &error), + IsOkStatus(&error)); + + ASSERT_THAT(quirks()->EnsureDbSchema(&connection, schema_name, &error), + IsOkStatus(&error)); + ASSERT_THAT(quirks()->DropTable(&connection, table_name, schema_name, &error), + IsOkStatus(&error)); + + { + adbc_validation::Handle statement; + ASSERT_THAT(AdbcStatementNew(&connection, &statement.value, &error), + IsOkStatus(&error)); + + std::string create = + "CREATE TABLE \"" + schema_name + "\".\"" + table_name + "\" (ints INT)"; + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement.value, create.c_str(), &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, &error), + IsOkStatus(&error)); + } + + adbc_validation::StreamReader reader; + ASSERT_THAT(AdbcConnectionGetObjects(&connection, ADBC_OBJECT_DEPTH_TABLES, nullptr, + schema_name.c_str(), nullptr, nullptr, nullptr, + &reader.stream.value, &error), + IsOkStatus(&error)); + ASSERT_NO_FATAL_FAILURE(reader.GetSchema()); + ASSERT_NO_FATAL_FAILURE(reader.Next()); + ASSERT_NE(nullptr, reader.array->release); + ASSERT_GT(reader.array->length, 0); + + auto get_objects_data = adbc_validation::GetObjectsReader{&reader.array_view.value}; + ASSERT_NE(*get_objects_data, nullptr) + << "could not initialize the AdbcGetObjectsData object"; + + const auto catalog = adbc_validation::ConnectionGetOption( + &connection, ADBC_CONNECTION_OPTION_CURRENT_CATALOG, &error); + ASSERT_TRUE(catalog.has_value()); + + struct AdbcGetObjectsTable* table = InternalAdbcGetObjectsDataGetTableByName( + *get_objects_data, catalog->c_str(), schema_name.c_str(), table_name.c_str()); + ASSERT_NE(table, nullptr) << "could not find " << schema_name << "." << table_name + << " via GetObjects"; +} + TEST_F(PostgresConnectionTest, GetObjectsGetAllFindsPrimaryKey) { ASSERT_THAT(AdbcConnectionNew(&connection, &error), IsOkStatus(&error)); ASSERT_THAT(AdbcConnectionInit(&connection, &database, &error), IsOkStatus(&error)); diff --git a/python/adbc_driver_postgresql/tests/test_dbapi.py b/python/adbc_driver_postgresql/tests/test_dbapi.py index 0b25ef9238..5fe72cfa37 100644 --- a/python/adbc_driver_postgresql/tests/test_dbapi.py +++ b/python/adbc_driver_postgresql/tests/test_dbapi.py @@ -53,6 +53,48 @@ def test_conn_change_db_schema(postgres: dbapi.Connection) -> None: assert postgres.adbc_current_db_schema == "dbapischema" +def test_get_objects_schema_filter_outside_search_path( + postgres: dbapi.Connection, +) -> None: + schema_name = "dbapi_get_objects_test" + table_name = "schema_filter_table" + + # Regression test: adbc_get_objects(db_schema_filter=...) should not depend on the + # connection's current schema/search_path. + assert postgres.adbc_current_db_schema == "public" + + with postgres.cursor() as cur: + cur.execute(f'CREATE SCHEMA IF NOT EXISTS "{schema_name}"') + cur.execute(f'DROP TABLE IF EXISTS "{schema_name}"."{table_name}"') + cur.execute(f'CREATE TABLE "{schema_name}"."{table_name}" (ints INT)') + postgres.commit() + + assert postgres.adbc_current_db_schema == "public" + + metadata = ( + postgres.adbc_get_objects( + depth="tables", + db_schema_filter=schema_name, + table_name_filter=table_name, + ) + .read_all() + .to_pylist() + ) + + catalog_name = postgres.adbc_current_catalog + catalog = next( + (row for row in metadata if row["catalog_name"] == catalog_name), None + ) + assert catalog is not None + + schemas = catalog["catalog_db_schemas"] + assert len(schemas) == 1 + assert schemas[0]["db_schema_name"] == schema_name + tables = schemas[0]["db_schema_tables"] + assert len(tables) == 1 + assert tables[0]["table_name"] == table_name + + def test_conn_get_info(postgres: dbapi.Connection) -> None: info = postgres.adbc_get_info() assert info["driver_name"] == "ADBC PostgreSQL Driver"