Skip to content

Commit 1000cec

Browse files
committed
Use correct calls from adbc to get table schema
1 parent 5b43018 commit 1000cec

File tree

1 file changed

+57
-35
lines changed

1 file changed

+57
-35
lines changed

cosmotech/coal/utils/postgresql.py

Lines changed: 57 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,13 @@
55
# etc., to any person is prohibited unless it has been previously and
66
# specifically authorized by written means by Cosmo Tech.
77

8-
from adbc_driver_postgresql import dbapi
9-
from pyarrow import Table
10-
import pyarrow as pa
118
from typing import Optional
129
from urllib.parse import quote
10+
11+
import pyarrow as pa
12+
from adbc_driver_postgresql import dbapi
13+
from pyarrow import Table
14+
1315
from cosmotech.coal.utils.logger import LOGGER
1416

1517

@@ -39,25 +41,45 @@ def get_postgresql_table_schema(
3941
postgres_user: str,
4042
postgres_password: str,
4143
) -> Optional[pa.Schema]:
42-
postgresql_full_uri = generate_postgresql_full_uri(postgres_host,
43-
postgres_port,
44-
postgres_db,
45-
postgres_user,
46-
postgres_password)
44+
"""
45+
Get the schema of an existing PostgreSQL table using SQL queries.
4746
48-
with dbapi.connect(postgresql_full_uri) as conn:
49-
with conn.cursor() as curs:
50-
# Get table metadata using ADBC's get_objects
51-
table_info = curs.get_objects(
52-
table_name=target_table_name,
53-
schema_name=postgres_schema
54-
)
55-
56-
# If table exists, return its schema
57-
if table_info is not None and len(table_info) > 0:
58-
return table_info[0].schema
59-
60-
return None
47+
Args:
48+
target_table_name: Name of the table
49+
postgres_host: PostgreSQL host
50+
postgres_port: PostgreSQL port
51+
postgres_db: PostgreSQL database name
52+
postgres_schema: PostgreSQL schema name
53+
postgres_user: PostgreSQL username
54+
postgres_password: PostgreSQL password
55+
56+
Returns:
57+
PyArrow Schema if table exists, None otherwise
58+
"""
59+
LOGGER.debug(f"Getting schema for table {postgres_schema}.{target_table_name}")
60+
61+
postgresql_full_uri = generate_postgresql_full_uri(postgres_host,
62+
postgres_port,
63+
postgres_db,
64+
postgres_user,
65+
postgres_password)
66+
67+
with (dbapi.connect(postgresql_full_uri) as conn):
68+
try:
69+
catalog = conn.adbc_get_objects(depth="tables",
70+
catalog_filter=postgres_db,
71+
db_schema_filter=postgres_schema,
72+
table_name_filter=target_table_name).read_all().to_pylist()[0]
73+
schema = catalog["catalog_db_schemas"][0]
74+
table = schema["db_schema_tables"][0]
75+
if table["table_name"] == target_table_name:
76+
return conn.adbc_get_table_schema(
77+
target_table_name,
78+
db_schema_filter=postgres_schema,
79+
)
80+
except IndexError:
81+
LOGGER.warning(f"Table {postgres_schema}.{target_table_name} not found")
82+
return None
6183

6284

6385
def adapt_table_to_schema(
@@ -70,23 +92,23 @@ def adapt_table_to_schema(
7092
LOGGER.debug(f"Starting schema adaptation for table with {len(data)} rows")
7193
LOGGER.debug(f"Original schema: {data.schema}")
7294
LOGGER.debug(f"Target schema: {target_schema}")
73-
95+
7496
target_fields = {field.name: field.type for field in target_schema}
7597
new_columns = []
76-
98+
7799
# Track adaptations for summary
78100
added_columns = []
79101
dropped_columns = []
80102
type_conversions = []
81103
failed_conversions = []
82-
104+
83105
# Process each field in target schema
84106
for field_name, target_type in target_fields.items():
85107
if field_name in data.column_names:
86108
# Column exists - try to cast to target type
87109
col = data[field_name]
88110
original_type = col.type
89-
111+
90112
if original_type != target_type:
91113
LOGGER.debug(
92114
f"Attempting to cast column '{field_name}' "
@@ -115,23 +137,23 @@ def adapt_table_to_schema(
115137
LOGGER.debug(f"Adding missing column '{field_name}' with null values")
116138
new_columns.append(pa.nulls(len(data), type=target_type))
117139
added_columns.append(field_name)
118-
140+
119141
# Log columns that will be dropped
120142
dropped_columns = [
121-
name for name in data.column_names
143+
name for name in data.column_names
122144
if name not in target_fields
123145
]
124146
if dropped_columns:
125147
LOGGER.debug(
126148
f"Dropping extra columns not in target schema: {dropped_columns}"
127149
)
128-
150+
129151
# Create new table
130152
adapted_table = pa.Table.from_arrays(
131153
new_columns,
132154
schema=target_schema
133155
)
134-
156+
135157
# Log summary of adaptations
136158
LOGGER.debug("Schema adaptation summary:")
137159
if added_columns:
@@ -144,7 +166,7 @@ def adapt_table_to_schema(
144166
LOGGER.debug(
145167
f"- Failed conversions (filled with nulls): {failed_conversions}"
146168
)
147-
169+
148170
LOGGER.debug(f"Final adapted table schema: {adapted_table.schema}")
149171
return adapted_table
150172

@@ -164,7 +186,7 @@ def send_pyarrow_table_to_postgresql(
164186
f"Preparing to send data to PostgreSQL table '{postgres_schema}.{target_table_name}'"
165187
)
166188
LOGGER.debug(f"Input table has {len(data)} rows")
167-
189+
168190
# Get existing schema if table exists
169191
existing_schema = get_postgresql_table_schema(
170192
target_table_name,
@@ -175,7 +197,7 @@ def send_pyarrow_table_to_postgresql(
175197
postgres_user,
176198
postgres_password
177199
)
178-
200+
179201
if existing_schema is not None:
180202
LOGGER.debug(f"Found existing table with schema: {existing_schema}")
181203
if not replace:
@@ -185,7 +207,7 @@ def send_pyarrow_table_to_postgresql(
185207
LOGGER.debug("Replace mode enabled - skipping schema adaptation")
186208
else:
187209
LOGGER.debug("No existing table found - will create new table")
188-
210+
189211
# Proceed with ingestion
190212
total = 0
191213
postgresql_full_uri = generate_postgresql_full_uri(
@@ -195,7 +217,7 @@ def send_pyarrow_table_to_postgresql(
195217
postgres_user,
196218
postgres_password
197219
)
198-
220+
199221
LOGGER.debug("Connecting to PostgreSQL database")
200222
with dbapi.connect(postgresql_full_uri, autocommit=True) as conn:
201223
with conn.cursor() as curs:
@@ -208,6 +230,6 @@ def send_pyarrow_table_to_postgresql(
208230
"replace" if replace else "create_append",
209231
db_schema_name=postgres_schema
210232
)
211-
233+
212234
LOGGER.debug(f"Successfully ingested {total} rows")
213235
return total

0 commit comments

Comments (0)