Skip to content

Commit 56ffc4c

Browse files
Fix snowflake uploader array bind variable (#382)
* Fix snowflake uploader issue with array variable binding * Update version and changelog: Fix snowflake uploader issue with array variable binding * Changelog typo fix
1 parent 9589843 commit 56ffc4c

File tree

3 files changed

+59
-10
lines changed

3 files changed

+59
-10
lines changed

CHANGELOG.md

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
1-
## 0.5.2-dev1
1+
## 0.5.2
22

3-
### Enchancements
3+
### Enhancements
44

5+
* **Only embed elements with text** - Only embed elements with text to avoid errors from embedders and optimize calls to APIs.
56
* **Improved google drive precheck mechanism**
67
* **Added integration tests for google drive precheck and connector**
78

8-
## 0.5.2-dev0
9-
10-
### Enhancements
9+
### Fixes
1110

12-
* **Only embed elements with text** - Only embed elements with text to avoid errors from embedders and optimize calls to APIs.
11+
* **Fix Snowflake Uploader error with array variable binding**
1312

1413
## 0.5.1
1514

unstructured_ingest/__version__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.5.2-dev1" # pragma: no cover
1+
__version__ = "0.5.2" # pragma: no cover

unstructured_ingest/v2/processes/connectors/sql/snowflake.py

Lines changed: 53 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
import json
12
from contextlib import contextmanager
23
from dataclasses import dataclass, field
3-
from typing import TYPE_CHECKING, Generator, Optional
4+
from typing import TYPE_CHECKING, Any, Generator, Optional
45

56
import numpy as np
67
import pandas as pd
@@ -15,6 +16,7 @@
1516
SourceRegistryEntry,
1617
)
1718
from unstructured_ingest.v2.processes.connectors.sql.sql import (
19+
_DATE_COLUMNS,
1820
SQLAccessConfig,
1921
SqlBatchFileData,
2022
SQLConnectionConfig,
@@ -26,6 +28,7 @@
2628
SQLUploaderConfig,
2729
SQLUploadStager,
2830
SQLUploadStagerConfig,
31+
parse_date_string,
2932
)
3033

3134
if TYPE_CHECKING:
@@ -34,6 +37,17 @@
3437

3538
CONNECTOR_TYPE = "snowflake"
3639

40+
_ARRAY_COLUMNS = (
41+
"embeddings",
42+
"languages",
43+
"link_urls",
44+
"link_texts",
45+
"sent_from",
46+
"sent_to",
47+
"emphasized_text_contents",
48+
"emphasized_text_tags",
49+
)
50+
3751

3852
class SnowflakeAccessConfig(SQLAccessConfig):
3953
password: Optional[str] = Field(default=None, description="DB password")
@@ -160,6 +174,42 @@ class SnowflakeUploader(SQLUploader):
160174
connector_type: str = CONNECTOR_TYPE
161175
values_delimiter: str = "?"
162176

177+
def prepare_data(
178+
self, columns: list[str], data: tuple[tuple[Any, ...], ...]
179+
) -> list[tuple[Any, ...]]:
180+
output = []
181+
for row in data:
182+
parsed = []
183+
for column_name, value in zip(columns, row):
184+
if column_name in _DATE_COLUMNS:
185+
if value is None or pd.isna(value): # pandas is nan
186+
parsed.append(None)
187+
else:
188+
parsed.append(parse_date_string(value))
189+
elif column_name in _ARRAY_COLUMNS:
190+
if not isinstance(value, list) and (
191+
value is None or pd.isna(value)
192+
): # pandas is nan
193+
parsed.append(None)
194+
else:
195+
parsed.append(json.dumps(value))
196+
else:
197+
parsed.append(value)
198+
output.append(tuple(parsed))
199+
return output
200+
201+
def _parse_values(self, columns: list[str]) -> str:
202+
return ",".join(
203+
[
204+
(
205+
f"PARSE_JSON({self.values_delimiter})"
206+
if col in _ARRAY_COLUMNS
207+
else self.values_delimiter
208+
)
209+
for col in columns
210+
]
211+
)
212+
163213
def upload_dataframe(self, df: pd.DataFrame, file_data: FileData) -> None:
164214
if self.can_delete():
165215
self.delete_by_record_id(file_data=file_data)
@@ -173,10 +223,10 @@ def upload_dataframe(self, df: pd.DataFrame, file_data: FileData) -> None:
173223
self._fit_to_schema(df=df)
174224

175225
columns = list(df.columns)
176-
stmt = "INSERT INTO {table_name} ({columns}) VALUES({values})".format(
226+
stmt = "INSERT INTO {table_name} ({columns}) SELECT {values}".format(
177227
table_name=self.upload_config.table_name,
178228
columns=",".join(columns),
179-
values=",".join([self.values_delimiter for _ in columns]),
229+
values=self._parse_values(columns),
180230
)
181231
logger.info(
182232
f"writing a total of {len(df)} elements via"

0 commit comments

Comments
 (0)