Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions airbyte_cdk/sql/shared/catalog_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,18 @@ def get_stream_properties(
def get_primary_keys(
self,
stream_name: str,
) -> list[str]:
"""Return the primary keys for the given stream."""
pks = self.get_configured_stream_info(stream_name).primary_key
) -> list[str] | None:
"""Return the primary keys for the given stream.

We return `source_defined_primary_key` if set, or `primary_key` otherwise. If both are set,
we assume they should not differ, since Airbyte data integrity constraints do not
permit overruling a source's pre-defined primary keys. If neither is set, we return `None`.
"""
configured_stream = self.get_configured_stream_info(stream_name)
pks = configured_stream.stream.source_defined_primary_key or configured_stream.primary_key

if not pks:
return []
return None

normalized_pks: list[list[str]] = [
[LowerCaseNormalizer.normalize(c) for c in pk] for pk in pks
Expand Down
15 changes: 12 additions & 3 deletions airbyte_cdk/sql/shared/sql_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,9 +666,13 @@ def _merge_temp_table_to_final_table(
"""
nl = "\n"
columns = {self._quote_identifier(c) for c in self._get_sql_column_definitions(stream_name)}
pk_columns = {
self._quote_identifier(c) for c in self.catalog_provider.get_primary_keys(stream_name)
}
primary_keys = self.catalog_provider.get_primary_keys(stream_name)
if not primary_keys:
raise exc.AirbyteInternalError(
message="Cannot merge tables without primary keys. Primary keys are required for merge operations.",
context={"stream_name": stream_name},
)
pk_columns = {self._quote_identifier(c) for c in primary_keys}
non_pk_columns = columns - pk_columns
join_clause = f"{nl} AND ".join(f"tmp.{pk_col} = final.{pk_col}" for pk_col in pk_columns)
set_clause = f"{nl} , ".join(f"{col} = tmp.{col}" for col in non_pk_columns)
Expand Down Expand Up @@ -725,6 +729,11 @@ def _emulated_merge_temp_table_to_final_table(
final_table = self._get_table_by_name(final_table_name)
temp_table = self._get_table_by_name(temp_table_name)
pk_columns = self.catalog_provider.get_primary_keys(stream_name)
if not pk_columns:
raise exc.AirbyteInternalError(
message="Cannot merge tables without primary keys. Primary keys are required for merge operations.",
context={"stream_name": stream_name},
)

columns_to_update: set[str] = self._get_sql_column_definitions(
stream_name=stream_name
Expand Down
78 changes: 78 additions & 0 deletions unit_tests/sql/shared/test_catalog_providers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from unittest.mock import Mock

import pytest

from airbyte_cdk.models import AirbyteStream, ConfiguredAirbyteCatalog, ConfiguredAirbyteStream
from airbyte_cdk.sql.shared.catalog_providers import CatalogProvider


class TestCatalogProvider:
    """Test cases for the `CatalogProvider.get_primary_keys()` method."""

    @staticmethod
    def _wrap_keys(keys):
        """Convert a flat list of column names into the nested form the models receive.

        Each column name becomes its own single-element list, e.g.
        `["id1", "id2"]` -> `[["id1"], ["id2"]]`. `None` is passed through
        unchanged and an empty list stays an empty list, so the "unset" and
        "empty" cases remain distinguishable in the parametrized inputs.
        """
        if keys is None:
            return None
        return [[key] for key in keys]

    @pytest.mark.parametrize(
        "configured_primary_key,source_defined_primary_key,expected_result,test_description",
        [
            (["configured_id"], ["source_id"], ["source_id"], "prioritizes source when both set"),
            ([], ["source_id"], ["source_id"], "uses source when configured empty"),
            (None, ["source_id"], ["source_id"], "uses source when configured None"),
            (
                ["configured_id"],
                [],
                ["configured_id"],
                "falls back to configured when source empty",
            ),
            (
                ["configured_id"],
                None,
                ["configured_id"],
                "falls back to configured when source None",
            ),
            ([], [], None, "returns None when both empty"),
            (None, None, None, "returns None when both None"),
            ([], ["id1", "id2"], ["id1", "id2"], "handles composite keys from source"),
        ],
    )
    def test_get_primary_keys_parametrized(
        self, configured_primary_key, source_defined_primary_key, expected_result, test_description
    ):
        """Test primary key fallback logic with various input combinations.

        Builds a one-stream catalog whose stream-level
        `source_defined_primary_key` and configured `primary_key` are taken
        from the parameters, then checks what `get_primary_keys()` returns
        for that stream.
        """
        stream = AirbyteStream(
            name="test_stream",
            json_schema={
                "type": "object",
                "properties": {
                    "id": {"type": "string"},
                    "id1": {"type": "string"},
                    "id2": {"type": "string"},
                },
            },
            supported_sync_modes=["full_refresh"],
            source_defined_primary_key=self._wrap_keys(source_defined_primary_key),
        )
        configured_stream = ConfiguredAirbyteStream(
            stream=stream,
            sync_mode="full_refresh",
            destination_sync_mode="overwrite",
            primary_key=self._wrap_keys(configured_primary_key),
        )
        catalog = ConfiguredAirbyteCatalog(streams=[configured_stream])

        provider = CatalogProvider(catalog)
        result = provider.get_primary_keys("test_stream")

        # Surface the scenario label on failure so the broken case is obvious.
        assert result == expected_result, test_description
Loading