Skip to content

Commit 51026ee

Browse files
Feat: Add Cortex-compatible Snowflake SQL processor for storing vector data; Enable native merge upsert for Snowflake caches (#203)
Co-authored-by: Aaron ("AJ") Steers <aj@airbyte.io>
1 parent cd1327a commit 51026ee

File tree

10 files changed

+836
-276
lines changed

10 files changed

+836
-276
lines changed
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,18 @@
11
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
"""SQL processors."""

from __future__ import annotations

from airbyte._processors.sql.snowflakecortex import (
    SnowflakeCortexSqlProcessor,
    SnowflakeCortexTypeConverter,
)


# Public API of this subpackage. `snowflakecortex` is re-exported as a module
# attribute: the `from ... import` above also binds the submodule on this
# package, so `import *` can expose it.
__all__ = [
    # Classes
    "SnowflakeCortexSqlProcessor",
    "SnowflakeCortexTypeConverter",
    # modules
    "snowflakecortex",
]

airbyte/_processors/sql/snowflake.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ class SnowflakeSqlProcessor(SqlProcessorBase):
5151

5252
file_writer_class = JsonlWriter
5353
type_converter_class = SnowflakeTypeConverter
54+
supports_merge_insert = True
5455

5556
@overrides
5657
def _write_files_to_new_table(
Lines changed: 237 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,237 @@
1+
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.
2+
"""A Snowflake vector store implementation of the SQL processor."""
3+
4+
from __future__ import annotations
5+
6+
from textwrap import dedent, indent
7+
from typing import TYPE_CHECKING
8+
9+
import sqlalchemy
10+
from overrides import overrides
11+
from sqlalchemy import text
12+
13+
from airbyte import exceptions as exc
14+
from airbyte._processors.base import RecordProcessor
15+
from airbyte._processors.sql.snowflake import SnowflakeSqlProcessor, SnowflakeTypeConverter
16+
from airbyte.caches._catalog_manager import CatalogManager
17+
18+
19+
if TYPE_CHECKING:
20+
from pathlib import Path
21+
22+
from sqlalchemy.engine import Connection, Engine
23+
24+
from airbyte_cdk.models import ConfiguredAirbyteCatalog
25+
26+
from airbyte._processors.file.base import FileWriterBase
27+
from airbyte.caches.base import CacheBase
28+
29+
30+
class SnowflakeCortexTypeConverter(SnowflakeTypeConverter):
    """Type converter that maps JSON array columns to the Snowflake `VECTOR` type."""

    def __init__(
        self,
        conversion_map: dict | None = None,
        *,
        vector_length: int,
    ) -> None:
        """Record the vector length, then defer to the base converter."""
        self.vector_length = vector_length
        super().__init__(conversion_map)

    @overrides
    def to_sql_type(
        self,
        json_schema_property_def: dict[str, str | dict | list],
    ) -> sqlalchemy.types.TypeEngine:
        """Convert a value to a SQL type."""
        resolved_type = super().to_sql_type(json_schema_property_def)
        if not isinstance(resolved_type, sqlalchemy.types.ARRAY):
            return resolved_type

        # SQLAlchemy doesn't yet support the `VECTOR` data type.
        # We may want to remove this or update once this resolves:
        # https://github.com/snowflakedb/snowflake-sqlalchemy/issues/499
        return f"VECTOR(FLOAT, {self.vector_length})"
56+
57+
58+
class SnowflakeCortexSqlProcessor(SnowflakeSqlProcessor):
    """A Snowflake implementation for use with Cortex functions.

    Extends `SnowflakeSqlProcessor` to support the Snowflake `VECTOR` data
    type, which SQLAlchemy cannot reflect or emit DDL for. Schema inspection
    and column DDL are therefore performed through the vendor (Snowflake
    Python connector) client returned by `cache.get_vendor_client()`.
    """

    supports_merge_insert = True

    def __init__(
        self,
        cache: CacheBase,
        catalog: ConfiguredAirbyteCatalog,
        vector_length: int,
        source_name: str,
        stream_names: set[str],
        *,
        file_writer: FileWriterBase | None = None,
    ) -> None:
        """Custom initialization: Initialize type_converter with vector_length.

        Args:
            cache: The cache backing this processor.
            catalog: Configured catalog describing the incoming streams.
            vector_length: Length used for `VECTOR(FLOAT, n)` embedding columns.
            source_name: Name under which the source is registered.
            stream_names: Stream names from the catalog to register.
            file_writer: Optional file writer; defaults to `file_writer_class`.
        """
        self._catalog = catalog
        # TODO: see if we can get rid of the following (duplicate) assignment
        self.source_catalog = catalog
        self._vector_length = vector_length
        self._engine: Engine | None = None
        self._connection_to_reuse: Connection | None = None

        # Call the grandparent base class for the necessary initialization.
        # NOTE(review): this intentionally bypasses the intermediate classes'
        # `__init__` — presumably to avoid SQLAlchemy-driven setup that cannot
        # handle `VECTOR` columns; confirm before refactoring.
        RecordProcessor.__init__(self, cache=cache, catalog_manager=None)
        self._ensure_schema_exists()
        self._catalog_manager = CatalogManager(
            engine=self.get_sql_engine(),
            table_name_resolver=lambda stream_name: self.get_sql_table_name(stream_name),
        )

        # TODO: read streams and source from catalog if not provided

        # Initialize the catalog manager by registering the source.
        self.register_source(
            source_name=source_name,
            incoming_source_catalog=self._catalog,
            stream_names=stream_names,
        )
        self.file_writer = file_writer or self.file_writer_class(cache)
        self.type_converter = SnowflakeCortexTypeConverter(vector_length=vector_length)
        self._cached_table_definitions: dict[str, sqlalchemy.Table] = {}

    def _get_column_list_from_table(
        self,
        table_name: str,
    ) -> list[str]:
        """Get lower-cased column names for the passed table.

        This is overridden due to lack of SQLAlchemy compatibility for the
        `VECTOR` data type, so the table is described via the vendor client.
        """
        # Vendor (Snowflake connector) connection — not a SQLAlchemy Connection.
        conn = self.cache.get_vendor_client()
        try:
            cursor = conn.cursor()
            try:
                cursor.execute(f"DESCRIBE TABLE {table_name};")
                # The first field of each DESCRIBE TABLE row is the column name.
                column_names = [row[0].lower() for row in cursor.fetchall()]
            finally:
                # Fix: always release the cursor, even if execute/fetch raises.
                cursor.close()
        finally:
            # Fix: always release the connection (previously leaked on error).
            conn.close()
        return column_names

    @overrides
    def _ensure_compatible_table_schema(
        self,
        stream_name: str,
        *,
        raise_on_error: bool = True,
    ) -> bool:
        """Read the existing table schema using the Snowflake Python connector.

        Returns True when every stream column already exists in the table;
        otherwise raises `PyAirbyteCacheTableValidationError` (when
        `raise_on_error`) or returns False.
        """
        json_schema = self.get_stream_json_schema(stream_name)
        # Fix: materialize the keys view so the annotation (list[str]) is accurate.
        stream_column_names: list[str] = list(json_schema["properties"].keys())
        table_column_names: list[str] = self._get_column_list_from_table(stream_name)

        lower_case_table_column_names = self.normalizer.normalize_set(table_column_names)
        missing_columns = [
            stream_col
            for stream_col in stream_column_names
            if self.normalizer.normalize(stream_col) not in lower_case_table_column_names
        ]
        # TODO: shouldn't we just return False here, so missing tables can be created?
        if missing_columns:
            if raise_on_error:
                raise exc.PyAirbyteCacheTableValidationError(
                    violation="Cache table is missing expected columns.",
                    context={
                        "stream_column_names": stream_column_names,
                        "table_column_names": table_column_names,
                        "missing_columns": missing_columns,
                    },
                )
            return False  # Some columns are missing.

        return True  # All columns exist.

    @overrides
    def _write_files_to_new_table(
        self,
        files: list[Path],
        stream_name: str,
        batch_id: str,
    ) -> str:
        """Write files to a new (temporary) table and return its name.

        Files are staged with PUT into the table's internal stage, then loaded
        with COPY INTO, casting any column whose name contains "embedding" to
        the `VECTOR` type.
        """
        temp_table_name = self._create_table_for_loading(
            stream_name=stream_name,
            batch_id=batch_id,
        )
        internal_sf_stage_name = f"@%{temp_table_name}"

        def path_str(path: Path) -> str:
            # Escape backslashes so Windows paths survive inside the SQL literal.
            return str(path.absolute()).replace("\\", "\\\\")

        put_files_statements = "\n".join(
            [f"PUT 'file://{path_str(file_path)}' {internal_sf_stage_name};" for file_path in files]
        )
        self._execute_sql(put_files_statements)
        columns_list = [
            self._quote_identifier(c)
            for c in list(self._get_sql_column_definitions(stream_name).keys())
        ]
        files_list = ", ".join([f"'{f.name}'" for f in files])
        columns_list_str: str = indent("\n, ".join(columns_list), " " * 12)

        # The following two lines are different from SnowflakeSqlProcessor:
        # embedding columns are cast to VECTOR on load.
        vector_suffix = f"::Vector(Float, {self._vector_length})"
        variant_cols_str: str = ("\n" + " " * 21 + ", ").join(
            [
                f"$1:{self.normalizer.normalize(col)}{vector_suffix if 'embedding' in col else ''}"
                for col in columns_list
            ]
        )

        copy_statement = dedent(
            f"""
            COPY INTO {temp_table_name}
            (
                {columns_list_str}
            )
            FROM (
                SELECT {variant_cols_str}
                FROM {internal_sf_stage_name}
            )
            FILES = ( {files_list} )
            FILE_FORMAT = ( TYPE = JSON )
            ;
            """
        )
        self._execute_sql(copy_statement)
        return temp_table_name

    @overrides
    def _add_missing_columns_to_table(
        self,
        stream_name: str,
        table_name: str,
    ) -> None:
        """Use the Snowflake Python connector to add new columns to the table."""
        columns = self._get_sql_column_definitions(stream_name)
        existing_columns = self._get_column_list_from_table(table_name)
        # NOTE(review): `existing_columns` are lower-cased by
        # `_get_column_list_from_table`; this membership test assumes the
        # definition keys are lower-case as well — confirm.
        for column_name, column_type in columns.items():
            if column_name not in existing_columns:
                self._add_new_column_to_table(table_name, column_name, column_type)
                self._invalidate_table_cache(table_name)
        # Fix: removed a stray trailing `pass` statement.

    def _add_new_column_to_table(
        self,
        table_name: str,
        column_name: str,
        column_type: sqlalchemy.types.TypeEngine,
    ) -> None:
        """Add a single column to the table via the vendor client.

        Uses the Snowflake connector directly because SQLAlchemy cannot emit
        DDL for the `VECTOR` type.
        """
        # Vendor (Snowflake connector) connection — not a SQLAlchemy Connection.
        conn = self.cache.get_vendor_client()
        try:
            cursor = conn.cursor()
            try:
                # Fix: pass a plain SQL string. The vendor cursor does not
                # accept a SQLAlchemy `text()` clause.
                cursor.execute(
                    f"ALTER TABLE {self._fully_qualified(table_name)} "
                    f"ADD COLUMN {column_name} {column_type}"
                )
            finally:
                cursor.close()
        finally:
            conn.close()

airbyte/caches/base.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,11 @@ def get_database_name(self) -> str:
7474
"""Return the name of the database."""
7575
...
7676

77+
def get_vendor_client(self) -> object:
    """Return a vendor-native (non-SQLAlchemy) database connection.

    Base implementation is abstract: concrete cache classes that support a
    driver-level client must override this.

    Raises:
        NotImplementedError: Always, unless overridden by a subclass.
    """
    raise NotImplementedError(
        "This method needs to be implemented for specific databases"
    )
81+
7782
@final
7883
@property
7984
def streams(

airbyte/caches/snowflake.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from __future__ import annotations
2323

2424
from overrides import overrides
25+
from snowflake import connector
2526
from snowflake.sqlalchemy import URL
2627

2728
from airbyte._processors.sql.base import RecordDedupeMode
@@ -62,6 +63,18 @@ def get_sql_alchemy_url(self) -> SecretString:
6263
)
6364
)
6465

66+
def get_vendor_client(self) -> object:
    """Return a native Snowflake connector connection for this cache."""
    connect_params = {
        "user": self.username,
        "password": self.password,
        "account": self.account,
        "warehouse": self.warehouse,
        "database": self.database,
        "schema": self.schema_name,
        "role": self.role,
    }
    return connector.connect(**connect_params)
77+
6578
@overrides
6679
def get_database_name(self) -> str:
6780
"""Return the name of the database."""

airbyte/types.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
# We include them here for completeness.
2626
"object": sqlalchemy.types.JSON,
2727
"array": sqlalchemy.types.JSON,
28+
"vector_array": sqlalchemy.types.ARRAY,
2829
}
2930

3031

@@ -82,6 +83,9 @@ def _get_airbyte_type( # noqa: PLR0911 # Too many return statements
8283

8384
return "array", None
8485

86+
if json_schema_type == "vector_array":
87+
return "vector_array", "Float"
88+
8589
err_msg = f"Could not determine airbyte type from JSON schema type: {json_schema_property_def}"
8690
raise SQLTypeConversionError(err_msg)
8791

@@ -110,13 +114,16 @@ def get_json_type(cls) -> sqlalchemy.types.TypeEngine:
110114
"""Get the type to use for nested JSON data."""
111115
return sqlalchemy.types.JSON()
112116

113-
def to_sql_type(
117+
def to_sql_type( # noqa: PLR0911 # Too many return statements
114118
self,
115119
json_schema_property_def: dict[str, str | dict | list],
116120
) -> sqlalchemy.types.TypeEngine:
117121
"""Convert a value to a SQL type."""
118122
try:
119123
airbyte_type, _ = _get_airbyte_type(json_schema_property_def)
124+
# to-do - is there a better way to check the following
125+
if airbyte_type == "vector_array":
126+
return sqlalchemy.types.ARRAY(sqlalchemy.types.Float())
120127
sql_type = self.conversion_map[airbyte_type]
121128
except SQLTypeConversionError:
122129
print(f"Could not determine airbyte type from JSON schema: {json_schema_property_def}")

0 commit comments

Comments
 (0)