Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions airbyte_cdk/sql/shared/catalog_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,12 @@ def get_primary_keys(
stream_name: str,
) -> list[str]:
"""Return the primary keys for the given stream."""
pks = self.get_configured_stream_info(stream_name).primary_key
if not pks:
return []
configured_stream = self.get_configured_stream_info(stream_name)
pks = (
configured_stream.primary_key
or configured_stream.stream.source_defined_primary_key
or []
)

normalized_pks: list[list[str]] = [
[LowerCaseNormalizer.normalize(c) for c in pk] for pk in pks
Expand Down
66 changes: 66 additions & 0 deletions unit_tests/sql/shared/test_catalog_providers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from unittest.mock import Mock

import pytest

from airbyte_cdk.models import AirbyteStream, ConfiguredAirbyteCatalog, ConfiguredAirbyteStream
from airbyte_cdk.sql.shared.catalog_providers import CatalogProvider


class TestCatalogProvider:
"""Test cases for CatalogProvider.get_primary_keys() method."""

@pytest.mark.parametrize(
"configured_primary_key,source_defined_primary_key,expected_result,test_description",
[
(["configured_id"], ["source_id"], ["configured_id"], "uses configured when both set"),
([], ["source_id"], ["source_id"], "falls back to source when configured empty"),
(None, ["source_id"], ["source_id"], "falls back to source when configured None"),
([], [], [], "returns empty when both empty"),
(None, None, [], "returns empty when both None"),
([], ["id1", "id2"], ["id1", "id2"], "handles composite keys from source"),
],
)
def test_get_primary_keys_parametrized(
self, configured_primary_key, source_defined_primary_key, expected_result, test_description
):
"""Test primary key fallback logic with various input combinations."""
configured_pk_wrapped = (
None
if configured_primary_key is None
else [[pk] for pk in configured_primary_key]
if configured_primary_key
else []
)
source_pk_wrapped = (
None
if source_defined_primary_key is None
else [[pk] for pk in source_defined_primary_key]
if source_defined_primary_key
else []
)

stream = AirbyteStream(
name="test_stream",
json_schema={
"type": "object",
"properties": {
"id": {"type": "string"},
"id1": {"type": "string"},
"id2": {"type": "string"},
},
},
supported_sync_modes=["full_refresh"],
source_defined_primary_key=source_pk_wrapped,
)
configured_stream = ConfiguredAirbyteStream(
stream=stream,
sync_mode="full_refresh",
destination_sync_mode="overwrite",
primary_key=configured_pk_wrapped,
)
catalog = ConfiguredAirbyteCatalog(streams=[configured_stream])

provider = CatalogProvider(catalog)
result = provider.get_primary_keys("test_stream")

assert result == expected_result
Loading