Commit 450073e

Add lookup for existing schema in target PostgreSQL
If the table exists and a schema is found, use it to enforce type conversion of the sent data before ingestion.
1 parent 1b160dc commit 450073e
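
To illustrate the type enforcement this commit introduces, here is a minimal sketch (not part of the commit; column names and types are hypothetical) of the cast that adapt_table_to_schema performs when an incoming column's Arrow type differs from the one already stored in PostgreSQL:

    import pyarrow as pa
    import pyarrow.compute as pc

    # Hypothetical incoming data: 'id' arrives as int64
    data = pa.table({"id": pa.array([1, 2, 3], type=pa.int64())})

    # Hypothetical schema already present in the target table: 'id' is int32
    target_schema = pa.schema([pa.field("id", pa.int32())])

    # Mismatched columns are cast to the stored type before ingestion
    adapted = pa.Table.from_arrays(
        [pc.cast(data["id"], pa.int32())],
        schema=target_schema,
    )
    print(adapted.schema)  # id: int32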

File tree: 1 file changed

cosmotech/coal/utils/postgresql.py
Lines changed: 165 additions & 8 deletions
@@ -7,6 +7,9 @@
 
 from adbc_driver_postgresql import dbapi
 from pyarrow import Table
+import pyarrow as pa
+from typing import Optional
+from cosmotech.coal.utils.logger import LOGGER
 
 
 def generate_postgresql_full_uri(
@@ -23,6 +26,125 @@ def generate_postgresql_full_uri(
             f'/{postgres_db}')
 
 
+def get_postgresql_table_schema(
+        target_table_name: str,
+        postgres_host: str,
+        postgres_port: str,
+        postgres_db: str,
+        postgres_schema: str,
+        postgres_user: str,
+        postgres_password: str,
+) -> Optional[pa.Schema]:
+    postgresql_full_uri = generate_postgresql_full_uri(postgres_host,
+                                                       postgres_port,
+                                                       postgres_db,
+                                                       postgres_user,
+                                                       postgres_password)
+
+    with dbapi.connect(postgresql_full_uri) as conn:
+        # Look up the table's Arrow schema through ADBC;
+        # adbc_get_table_schema raises if the table does not exist
+        try:
+            return conn.adbc_get_table_schema(
+                target_table_name,
+                db_schema_filter=postgres_schema,
+            )
+        except dbapi.Error:
+            # Table not found - no schema to enforce
+            return None
+
+
+def adapt_table_to_schema(
+        data: pa.Table,
+        target_schema: pa.Schema
+) -> pa.Table:
+    """
+    Adapt a PyArrow table to match a target schema with detailed logging.
+    """
+    LOGGER.debug(f"Starting schema adaptation for table with {len(data)} rows")
+    LOGGER.debug(f"Original schema: {data.schema}")
+    LOGGER.debug(f"Target schema: {target_schema}")
+
+    target_fields = {field.name: field.type for field in target_schema}
+    new_columns = []
+
+    # Track adaptations for summary
+    added_columns = []
+    dropped_columns = []
+    type_conversions = []
+    failed_conversions = []
+
+    # Process each field in target schema
+    for field_name, target_type in target_fields.items():
+        if field_name in data.column_names:
+            # Column exists - try to cast to target type
+            col = data[field_name]
+            original_type = col.type
+
+            if original_type != target_type:
+                LOGGER.debug(
+                    f"Attempting to cast column '{field_name}' "
+                    f"from {original_type} to {target_type}"
+                )
+                try:
+                    new_col = pa.compute.cast(col, target_type)
+                    new_columns.append(new_col)
+                    type_conversions.append(
+                        f"{field_name}: {original_type} -> {target_type}"
+                    )
+                except pa.ArrowInvalid as e:
+                    LOGGER.warning(
+                        f"Failed to cast column '{field_name}' "
+                        f"from {original_type} to {target_type}. "
+                        f"Filling with nulls. Error: {str(e)}"
+                    )
+                    new_columns.append(pa.nulls(len(data), type=target_type))
+                    failed_conversions.append(
+                        f"{field_name}: {original_type} -> {target_type}"
+                    )
+            else:
+                new_columns.append(col)
+        else:
+            # Column doesn't exist - add nulls
+            LOGGER.debug(f"Adding missing column '{field_name}' with null values")
+            new_columns.append(pa.nulls(len(data), type=target_type))
+            added_columns.append(field_name)
+
+    # Log columns that will be dropped
+    dropped_columns = [
+        name for name in data.column_names
+        if name not in target_fields
+    ]
+    if dropped_columns:
+        LOGGER.debug(
+            f"Dropping extra columns not in target schema: {dropped_columns}"
+        )
+
+    # Create new table
+    adapted_table = pa.Table.from_arrays(
+        new_columns,
+        schema=target_schema
+    )
+
+    # Log summary of adaptations
+    LOGGER.debug("Schema adaptation summary:")
+    if added_columns:
+        LOGGER.debug(f"- Added columns (filled with nulls): {added_columns}")
+    if dropped_columns:
+        LOGGER.debug(f"- Dropped columns: {dropped_columns}")
+    if type_conversions:
+        LOGGER.debug(f"- Successful type conversions: {type_conversions}")
+    if failed_conversions:
+        LOGGER.debug(
+            f"- Failed conversions (filled with nulls): {failed_conversions}"
+        )
+
+    LOGGER.debug(f"Final adapted table schema: {adapted_table.schema}")
+    return adapted_table
+
+
 def send_pyarrow_table_to_postgresql(
         data: Table,
         target_table_name: str,
@@ -34,19 +156,54 @@ def send_pyarrow_table_to_postgresql(
         postgres_password: str,
         replace: bool
 ) -> int:
+    LOGGER.debug(
+        f"Preparing to send data to PostgreSQL table '{postgres_schema}.{target_table_name}'"
+    )
+    LOGGER.debug(f"Input table has {len(data)} rows")
+
+    # Get existing schema if table exists
+    existing_schema = get_postgresql_table_schema(
+        target_table_name,
+        postgres_host,
+        postgres_port,
+        postgres_db,
+        postgres_schema,
+        postgres_user,
+        postgres_password
+    )
+
+    if existing_schema is not None:
+        LOGGER.debug(f"Found existing table with schema: {existing_schema}")
+        if not replace:
+            LOGGER.debug("Adapting incoming data to match existing schema")
+            data = adapt_table_to_schema(data, existing_schema)
+        else:
+            LOGGER.debug("Replace mode enabled - skipping schema adaptation")
+    else:
+        LOGGER.debug("No existing table found - will create new table")
+
+    # Proceed with ingestion
     total = 0
-
-    postgresql_full_uri = generate_postgresql_full_uri(postgres_host,
-                                                       postgres_port,
-                                                       postgres_db,
-                                                       postgres_user,
-                                                       postgres_password)
+    postgresql_full_uri = generate_postgresql_full_uri(
+        postgres_host,
+        postgres_port,
+        postgres_db,
+        postgres_user,
+        postgres_password
+    )
+
+    LOGGER.debug("Connecting to PostgreSQL database")
     with dbapi.connect(postgresql_full_uri, autocommit=True) as conn:
         with conn.cursor() as curs:
+            LOGGER.debug(
+                f"Ingesting data with mode: {'replace' if replace else 'create_append'}"
+            )
             total += curs.adbc_ingest(
                 target_table_name,
                 data,
                 "replace" if replace else "create_append",
-                db_schema_name=postgres_schema)
-
+                db_schema_name=postgres_schema
+            )
+
+    LOGGER.debug(f"Successfully ingested {total} rows")
     return total
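
For reference, a usage sketch of the new lookup on its own (connection details below are placeholders, not part of the commit); it returns the table's Arrow schema, or None when the table has not been created yet:

    from cosmotech.coal.utils.postgresql import get_postgresql_table_schema

    schema = get_postgresql_table_schema(
        target_table_name="my_table",  # placeholder table name
        postgres_host="localhost",
        postgres_port="5432",
        postgres_db="cosmotech",
        postgres_schema="public",
        postgres_user="user",
        postgres_password="password",
    )
    if schema is None:
        print("Table does not exist yet")
    else:
        print(schema)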

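And an end-to-end sketch of the updated entry point (again with placeholder connection details): with replace=False the incoming table is adapted to any existing schema before curs.adbc_ingest appends it.

    from pyarrow import Table
    from cosmotech.coal.utils.postgresql import send_pyarrow_table_to_postgresql

    data = Table.from_pydict({"id": [1, 2], "label": ["a", "b"]})

    rows = send_pyarrow_table_to_postgresql(
        data=data,
        target_table_name="my_table",  # placeholder table name
        postgres_host="localhost",
        postgres_port="5432",
        postgres_db="cosmotech",
        postgres_schema="public",
        postgres_user="user",
        postgres_password="password",
        replace=False,
    )
    print(f"{rows} rows ingested")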