Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions airbyte_cdk/sql/shared/catalog_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,16 @@ def get_stream_properties(
def get_primary_keys(
self,
stream_name: str,
) -> list[str]:
"""Return the primary keys for the given stream."""
pks = self.get_configured_stream_info(stream_name).primary_key
) -> list[str] | None:
"""Return the primary keys for the given stream.

We return `source_defined_primary_key` if set, or `primary_key` otherwise. If both are set, we assume they should not differ, since Airbyte data integrity constraints do not permit overruling a source's pre-defined primary keys. If neither is set, we return `None`.
"""
configured_stream = self.get_configured_stream_info(stream_name)
pks = configured_stream.stream.source_defined_primary_key or configured_stream.primary_key

if not pks:
return []
return None

normalized_pks: list[list[str]] = [
[LowerCaseNormalizer.normalize(c) for c in pk] for pk in pks
Expand Down
5 changes: 3 additions & 2 deletions airbyte_cdk/sql/shared/sql_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,7 +667,8 @@ def _merge_temp_table_to_final_table(
nl = "\n"
columns = {self._quote_identifier(c) for c in self._get_sql_column_definitions(stream_name)}
pk_columns = {
self._quote_identifier(c) for c in self.catalog_provider.get_primary_keys(stream_name)
self._quote_identifier(c)
for c in (self.catalog_provider.get_primary_keys(stream_name) or [])
}
non_pk_columns = columns - pk_columns
join_clause = f"{nl} AND ".join(f"tmp.{pk_col} = final.{pk_col}" for pk_col in pk_columns)
Expand Down Expand Up @@ -724,7 +725,7 @@ def _emulated_merge_temp_table_to_final_table(
"""
final_table = self._get_table_by_name(final_table_name)
temp_table = self._get_table_by_name(temp_table_name)
pk_columns = self.catalog_provider.get_primary_keys(stream_name)
pk_columns = self.catalog_provider.get_primary_keys(stream_name) or []

columns_to_update: set[str] = self._get_sql_column_definitions(
stream_name=stream_name
Expand Down
78 changes: 78 additions & 0 deletions unit_tests/sql/shared/test_catalog_providers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from unittest.mock import Mock

import pytest

from airbyte_cdk.models import AirbyteStream, ConfiguredAirbyteCatalog, ConfiguredAirbyteStream
from airbyte_cdk.sql.shared.catalog_providers import CatalogProvider


class TestCatalogProvider:
    """Unit tests for `CatalogProvider.get_primary_keys()`."""

    @staticmethod
    def _as_key_paths(key_names):
        """Wrap each key name in its own single-element list.

        `None` is passed through unchanged and an empty list stays empty, so
        the "unset" and "empty" cases remain distinguishable in the catalog.
        """
        if key_names is None:
            return None
        return [[name] for name in key_names]

    @pytest.mark.parametrize(
        "configured_primary_key,source_defined_primary_key,expected_result,test_description",
        [
            (["configured_id"], ["source_id"], ["source_id"], "prioritizes source when both set"),
            ([], ["source_id"], ["source_id"], "uses source when configured empty"),
            (None, ["source_id"], ["source_id"], "uses source when configured None"),
            (
                ["configured_id"],
                [],
                ["configured_id"],
                "falls back to configured when source empty",
            ),
            (
                ["configured_id"],
                None,
                ["configured_id"],
                "falls back to configured when source None",
            ),
            ([], [], None, "returns None when both empty"),
            (None, None, None, "returns None when both None"),
            ([], ["id1", "id2"], ["id1", "id2"], "handles composite keys from source"),
        ],
    )
    def test_get_primary_keys_parametrized(
        self, configured_primary_key, source_defined_primary_key, expected_result, test_description
    ):
        """Check which primary keys are returned for each configured/source combination.

        Each case builds a one-stream configured catalog whose stream-level
        `source_defined_primary_key` and configured `primary_key` are set from
        the parameters, then compares `get_primary_keys("test_stream")` with
        `expected_result`. `test_description` only labels the case.
        """
        # The catalog models take key paths (list of lists), while the
        # parameters above are written as flat lists of names for readability.
        source_key_paths = self._as_key_paths(source_defined_primary_key)
        configured_key_paths = self._as_key_paths(configured_primary_key)

        schema = {
            "type": "object",
            "properties": {
                "id": {"type": "string"},
                "id1": {"type": "string"},
                "id2": {"type": "string"},
            },
        }
        airbyte_stream = AirbyteStream(
            name="test_stream",
            json_schema=schema,
            supported_sync_modes=["full_refresh"],
            source_defined_primary_key=source_key_paths,
        )
        configured = ConfiguredAirbyteStream(
            stream=airbyte_stream,
            sync_mode="full_refresh",
            destination_sync_mode="overwrite",
            primary_key=configured_key_paths,
        )

        catalog_provider = CatalogProvider(ConfiguredAirbyteCatalog(streams=[configured]))
        actual = catalog_provider.get_primary_keys("test_stream")

        assert actual == expected_result
Loading