Skip to content

Commit 1e45e42

Browse files
authored
Fix: Resolve issue where mixed-case stream properties would result in missing data (#114)
1 parent 5a8ebf5 commit 1e45e42

File tree

9 files changed

+376
-63
lines changed

9 files changed

+376
-63
lines changed

airbyte/_processors/sql/base.py

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import sqlalchemy
1414
import ulid
1515
from overrides import overrides
16+
from pandas import Index
1617
from sqlalchemy import (
1718
Column,
1819
Table,
@@ -29,7 +30,7 @@
2930

3031
from airbyte import exceptions as exc
3132
from airbyte._processors.base import RecordProcessor
32-
from airbyte._util.text_util import lower_case_set
33+
from airbyte._util.name_normalizers import LowerCaseNormalizer
3334
from airbyte.caches._catalog_manager import CatalogManager
3435
from airbyte.datasets._sql import CachedDataset
3536
from airbyte.progress import progress
@@ -73,9 +74,17 @@ class SqlProcessorBase(RecordProcessor):
7374
"""A base class to be used for SQL Caches."""
7475

7576
type_converter_class: type[SQLTypeConverter] = SQLTypeConverter
77+
"""The type converter class to use for converting JSON schema types to SQL types."""
78+
79+
normalizer = LowerCaseNormalizer
80+
"""The name normalizer to user for table and column name normalization."""
81+
7682
file_writer_class: type[FileWriterBase]
83+
"""The file writer class to use for writing files to the cache."""
7784

7885
supports_merge_insert = False
86+
"""True if the database supports the MERGE INTO syntax."""
87+
7988
use_singleton_connection = False # If true, the same connection is used for all operations.
8089

8190
# Constructor:
@@ -197,7 +206,7 @@ def get_sql_table_name(
197206

198207
# TODO: Add default prefix based on the source name.
199208

200-
return self._normalize_table_name(
209+
return self.normalizer.normalize(
201210
f"{table_prefix}{stream_name}{self.cache.table_suffix}",
202211
)
203212

@@ -324,7 +333,7 @@ def _get_temp_table_name(
324333
) -> str:
325334
"""Return a new (unique) temporary table name."""
326335
batch_id = batch_id or str(ulid.ULID())
327-
return self._normalize_table_name(f"{stream_name}_{batch_id}")
336+
return self.normalizer.normalize(f"{stream_name}_{batch_id}")
328337

329338
def _fully_qualified(
330339
self,
@@ -414,11 +423,11 @@ def _ensure_compatible_table_schema(
414423
stream_column_names: list[str] = json_schema["properties"].keys()
415424
table_column_names: list[str] = self.get_sql_table(stream_name).columns.keys()
416425

417-
lower_case_table_column_names = lower_case_set(table_column_names)
426+
lower_case_table_column_names = self.normalizer.normalize_set(table_column_names)
418427
missing_columns = [
419428
stream_col
420429
for stream_col in stream_column_names
421-
if stream_col.lower() not in lower_case_table_column_names
430+
if self.normalizer.normalize(stream_col) not in lower_case_table_column_names
422431
]
423432
if missing_columns:
424433
if raise_on_error:
@@ -452,17 +461,12 @@ def _create_table(
452461
"""
453462
_ = self._execute_sql(cmd)
454463

455-
def _normalize_column_name(
456-
self,
457-
raw_name: str,
458-
) -> str:
459-
return raw_name.lower().replace(" ", "_").replace("-", "_")
460-
461-
def _normalize_table_name(
464+
def _get_stream_properties(
462465
self,
463-
raw_name: str,
464-
) -> str:
465-
return raw_name.lower().replace(" ", "_").replace("-", "_")
466+
stream_name: str,
467+
) -> dict[str, dict]:
468+
"""Return the names of the top-level properties for the given stream."""
469+
return self._get_stream_json_schema(stream_name)["properties"]
466470

467471
@final
468472
def _get_sql_column_definitions(
@@ -471,9 +475,9 @@ def _get_sql_column_definitions(
471475
) -> dict[str, sqlalchemy.types.TypeEngine]:
472476
"""Return the column definitions for the given stream."""
473477
columns: dict[str, sqlalchemy.types.TypeEngine] = {}
474-
properties = self._get_stream_json_schema(stream_name)["properties"]
478+
properties = self._get_stream_properties(stream_name)
475479
for property_name, json_schema_property_def in properties.items():
476-
clean_prop_name = self._normalize_column_name(property_name)
480+
clean_prop_name = self.normalizer.normalize(property_name)
477481
columns[clean_prop_name] = self.type_converter.to_sql_type(
478482
json_schema_property_def,
479483
)
@@ -635,6 +639,12 @@ def _write_files_to_new_table(
635639
},
636640
)
637641

642+
# Normalize all column names to lower case.
643+
dataframe.columns = Index(
644+
[LowerCaseNormalizer.normalize(col) for col in dataframe.columns]
645+
)
646+
647+
# Write the data to the table.
638648
dataframe.to_sql(
639649
temp_table_name,
640650
self.get_sql_alchemy_url(),

airbyte/_processors/sql/duckdb.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ def _write_files_to_new_table(
8484
stream_name=stream_name,
8585
batch_id=batch_id,
8686
)
87+
properties_list = list(self._get_stream_properties(stream_name).keys())
8788
columns_list = list(self._get_sql_column_definitions(stream_name=stream_name).keys())
8889
columns_list_str = indent(
8990
"\n, ".join([self._quote_identifier(c) for c in columns_list]),
@@ -93,9 +94,14 @@ def _write_files_to_new_table(
9394
columns_type_map = indent(
9495
"\n, ".join(
9596
[
96-
f"{self._quote_identifier(c)}: "
97-
f"{self._get_sql_column_definitions(stream_name)[c]!s}"
98-
for c in columns_list
97+
self._quote_identifier(prop_name)
98+
+ ": "
99+
+ str(
100+
self._get_sql_column_definitions(stream_name)[
101+
self.normalizer.normalize(prop_name)
102+
]
103+
)
104+
for prop_name in properties_list
99105
]
100106
),
101107
" ",

airbyte/_processors/sql/snowflake.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,14 +67,16 @@ def _write_files_to_new_table(
6767
]
6868
)
6969
self._execute_sql(put_files_statements)
70-
70+
properties_list: list[str] = list(self._get_stream_properties(stream_name).keys())
7171
columns_list = [
7272
self._quote_identifier(c)
7373
for c in list(self._get_sql_column_definitions(stream_name).keys())
7474
]
7575
files_list = ", ".join([f"'{f.name}'" for f in files])
7676
columns_list_str: str = indent("\n, ".join(columns_list), " " * 12)
77-
variant_cols_str: str = ("\n" + " " * 21 + ", ").join([f"$1:{col}" for col in columns_list])
77+
variant_cols_str: str = ("\n" + " " * 21 + ", ").join(
78+
[f"$1:{col}" for col in properties_list]
79+
)
7880
copy_statement = dedent(
7981
f"""
8082
COPY INTO {temp_table_name}

airbyte/_util/name_normalizers.py

Lines changed: 205 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,205 @@
1+
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
2+
"""Name normalizer classes."""
3+
4+
from __future__ import annotations
5+
6+
import abc
7+
from typing import TYPE_CHECKING, Any
8+
9+
10+
if TYPE_CHECKING:
11+
from collections.abc import Iterable, Iterator
12+
13+
14+
class NameNormalizerBase(abc.ABC):
15+
"""Abstract base class for name normalizers."""
16+
17+
@staticmethod
18+
@abc.abstractmethod
19+
def normalize(name: str) -> str:
20+
"""Return the normalized name."""
21+
...
22+
23+
@classmethod
24+
def normalize_set(cls, str_iter: Iterable[str]) -> set[str]:
25+
"""Converts string iterable to a set of lower case strings."""
26+
return {cls.normalize(s) for s in str_iter}
27+
28+
@classmethod
29+
def normalize_list(cls, str_iter: Iterable[str]) -> list[str]:
30+
"""Converts string iterable to a list of lower case strings."""
31+
return [cls.normalize(s) for s in str_iter]
32+
33+
@classmethod
34+
def check_matched(cls, name1: str, name2: str) -> bool:
35+
"""Return True if the two names match after each is normalized."""
36+
return cls.normalize(name1) == cls.normalize(name2)
37+
38+
@classmethod
39+
def check_normalized(cls, name: str) -> bool:
40+
"""Return True if the name is already normalized."""
41+
return cls.normalize(name) == name
42+
43+
44+
class LowerCaseNormalizer(NameNormalizerBase):
45+
"""A name normalizer that converts names to lower case."""
46+
47+
@staticmethod
48+
def normalize(name: str) -> str:
49+
"""Return the normalized name."""
50+
return name.lower().replace(" ", "_").replace("-", "_")
51+
52+
53+
class CaseInsensitiveDict(dict[str, Any]):
    """A case-aware, case-insensitive dictionary implementation.

    It has these behaviors:
    - When a key is retrieved, deleted, or checked for existence via the item
      protocol (`d[key]`, `del d[key]`, `key in d`), it is always checked in a
      case-insensitive manner.
    - The original case is stored in a separate dictionary, so that the original case
      can be retrieved when needed.

    There are two ways to store keys internally:
    - If normalize_keys is True, the keys are normalized using the given normalizer.
    - If normalize_keys is False, the original case of the keys is stored.

    With regard to missing values, the dictionary accepts an 'expected_keys' input.
    When set, the dictionary will be initialized with the given keys. If a key is not
    found in the input data, it will be initialized with a value of None. When
    provided, the 'expected_keys' input will also determine the original case of the
    keys.
    """

    def _display_case(self, key: str) -> str:
        """Return the original case of the key.

        Raises:
            KeyError: If no original case has been recorded for the key.
        """
        return self._pretty_case_keys[self._normalizer.normalize(key)]

    def _index_case(self, key: str) -> str:
        """Return the internal (storage) case of the key.

        If normalize_keys is True, return the normalized key. Otherwise, return the
        recorded original case of the key; a key that has not been recorded yet is
        returned unchanged, so lookups of unknown keys and inserts of new keys do
        not fail here.
        """
        normalized = self._normalizer.normalize(key)
        if self._normalize_keys:
            return normalized

        return self._pretty_case_keys.get(normalized, key)

    def __init__(
        self,
        from_dict: dict,
        *,
        normalize_keys: bool = True,
        normalizer: type[NameNormalizerBase] | None = None,
        expected_keys: list[str] | None = None,
    ) -> None:
        """Initialize the dictionary with the given data.

        If normalize_keys is True, the keys will be normalized using the given normalizer.
        If expected_keys is provided, the dictionary will be initialized with the given keys.
        Keys of from_dict that are not among the expected keys are added as new entries.
        """
        super().__init__()

        # If no normalizer is provided, use LowerCaseNormalizer.
        self._normalize_keys = normalize_keys
        self._normalizer: type[NameNormalizerBase] = normalizer or LowerCaseNormalizer

        # If no expected keys are provided, use all keys from the input dictionary.
        if not expected_keys:
            expected_keys = list(from_dict.keys())

        # Store a lookup from normalized keys to pretty cased (originally cased) keys.
        # The lookup key must be built exactly the way `_display_case` reads it.
        self._pretty_case_keys: dict[str, str] = {
            self._normalizer.normalize(pretty_case): pretty_case
            for pretty_case in expected_keys
        }

        if normalize_keys:
            index_keys = [self._normalizer.normalize(key) for key in expected_keys]
        else:
            index_keys = expected_keys

        # Start by initializing all values to None.
        self.update(dict.fromkeys(index_keys))
        for k, v in from_dict.items():
            self[self._index_case(k)] = v

    def __getitem__(self, key: str) -> Any:  # noqa: ANN401
        if super().__contains__(key):
            return super().__getitem__(key)

        index_key = self._index_case(key)
        if super().__contains__(index_key):
            return super().__getitem__(index_key)

        raise KeyError(key)

    def __setitem__(self, key: str, value: Any) -> None:  # noqa: ANN401
        if super().__contains__(key):
            super().__setitem__(key, value)
            return

        index_key = self._index_case(key)
        if super().__contains__(index_key):
            super().__setitem__(index_key, value)
            return

        # Store the pretty cased (originally cased) key:
        self._pretty_case_keys[self._normalizer.normalize(key)] = key

        # Store the data under the internal key. It is recomputed on purpose: the
        # pretty-case entry written above may have replaced a stale one.
        super().__setitem__(self._index_case(key), value)

    def __delitem__(self, key: str) -> None:
        if super().__contains__(key):
            super().__delitem__(key)
            return

        index_key = self._index_case(key)
        if super().__contains__(index_key):
            super().__delitem__(index_key)
            return

        raise KeyError(key)

    def __contains__(self, key: object) -> bool:
        # Only strings can be keys here; report other objects as absent rather than
        # relying on an `assert`, which is removed when Python runs with -O.
        if not isinstance(key, str):
            return False

        return super().__contains__(key) or super().__contains__(self._index_case(key))

    def __iter__(self) -> Any:  # noqa: ANN401
        return iter(super().__iter__())

    def __len__(self) -> int:
        return super().__len__()

    def __eq__(self, other: object) -> bool:
        if isinstance(other, CaseInsensitiveDict):
            return dict(self) == dict(other)

        if isinstance(other, dict):
            # Compare against plain dicts ignoring key case. Non-string keys are
            # compared as-is instead of failing on `.lower()`.
            return {k.lower(): v for k, v in self.items()} == {
                (k.lower() if isinstance(k, str) else k): v for k, v in other.items()
            }

        return False
177+
178+
179+
def normalize_records(
180+
records: Iterable[dict[str, Any]],
181+
expected_keys: list[str],
182+
) -> Iterator[CaseInsensitiveDict]:
183+
"""Add missing columns to the record with null values.
184+
185+
Also conform the column names to the case in the catalog.
186+
187+
This is a generator that yields CaseInsensitiveDicts, which allows for case-insensitive
188+
lookups of columns. This is useful because the case of the columns in the records may
189+
not match the case of the columns in the catalog.
190+
"""
191+
yield from (
192+
CaseInsensitiveDict(
193+
from_dict=record,
194+
expected_keys=expected_keys,
195+
)
196+
for record in records
197+
)
198+
199+
200+
__all__ = [
201+
"NameNormalizerBase",
202+
"LowerCaseNormalizer",
203+
"CaseInsensitiveDict",
204+
"normalize_records",
205+
]

airbyte/_util/text_util.py

Lines changed: 0 additions & 15 deletions
This file was deleted.

0 commit comments

Comments
 (0)