1+ import json
12from contextlib import contextmanager
23from dataclasses import dataclass , field
3- from typing import TYPE_CHECKING , Generator , Optional
4+ from typing import TYPE_CHECKING , Any , Generator , Optional
45
56import numpy as np
67import pandas as pd
1516 SourceRegistryEntry ,
1617)
1718from unstructured_ingest .v2 .processes .connectors .sql .sql import (
19+ _DATE_COLUMNS ,
1820 SQLAccessConfig ,
1921 SqlBatchFileData ,
2022 SQLConnectionConfig ,
2628 SQLUploaderConfig ,
2729 SQLUploadStager ,
2830 SQLUploadStagerConfig ,
31+ parse_date_string ,
2932)
3033
3134if TYPE_CHECKING :
3437
3538CONNECTOR_TYPE = "snowflake"
3639
40+ _ARRAY_COLUMNS = (
41+ "embeddings" ,
42+ "languages" ,
43+ "link_urls" ,
44+ "link_texts" ,
45+ "sent_from" ,
46+ "sent_to" ,
47+ "emphasized_text_contents" ,
48+ "emphasized_text_tags" ,
49+ )
50+
3751
3852class SnowflakeAccessConfig (SQLAccessConfig ):
3953 password : Optional [str ] = Field (default = None , description = "DB password" )
@@ -160,6 +174,42 @@ class SnowflakeUploader(SQLUploader):
160174 connector_type : str = CONNECTOR_TYPE
161175 values_delimiter : str = "?"
162176
177+ def prepare_data (
178+ self , columns : list [str ], data : tuple [tuple [Any , ...], ...]
179+ ) -> list [tuple [Any , ...]]:
180+ output = []
181+ for row in data :
182+ parsed = []
183+ for column_name , value in zip (columns , row ):
184+ if column_name in _DATE_COLUMNS :
185+ if value is None or pd .isna (value ): # pandas is nan
186+ parsed .append (None )
187+ else :
188+ parsed .append (parse_date_string (value ))
189+ elif column_name in _ARRAY_COLUMNS :
190+ if not isinstance (value , list ) and (
191+ value is None or pd .isna (value )
192+ ): # pandas is nan
193+ parsed .append (None )
194+ else :
195+ parsed .append (json .dumps (value ))
196+ else :
197+ parsed .append (value )
198+ output .append (tuple (parsed ))
199+ return output
200+
201+ def _parse_values (self , columns : list [str ]) -> str :
202+ return "," .join (
203+ [
204+ (
205+ f"PARSE_JSON({ self .values_delimiter } )"
206+ if col in _ARRAY_COLUMNS
207+ else self .values_delimiter
208+ )
209+ for col in columns
210+ ]
211+ )
212+
163213 def upload_dataframe (self , df : pd .DataFrame , file_data : FileData ) -> None :
164214 if self .can_delete ():
165215 self .delete_by_record_id (file_data = file_data )
@@ -173,10 +223,10 @@ def upload_dataframe(self, df: pd.DataFrame, file_data: FileData) -> None:
173223 self ._fit_to_schema (df = df )
174224
175225 columns = list (df .columns )
176- stmt = "INSERT INTO {table_name} ({columns}) VALUES( {values}) " .format (
226+ stmt = "INSERT INTO {table_name} ({columns}) SELECT {values}" .format (
177227 table_name = self .upload_config .table_name ,
178228 columns = "," .join (columns ),
179- values = "," . join ([ self .values_delimiter for _ in columns ] ),
229+ values = self ._parse_values ( columns ),
180230 )
181231 logger .info (
182232 f"writing a total of { len (df )} elements via"
0 commit comments