Skip to content

Commit aa5d89f

Browse files
authored
📒 feat: Support CSV loading for non-UTF-8 files (#169)
- Added a function to determine file encoding based on BOM markers.
- Introduced a function to remove temporary UTF-8 files created during encoding conversion.
- Updated the CSV loading logic to handle non-UTF-8 encodings by creating temporary UTF-8 files for processing.
- Integrated cleanup functionality in the embedding routes to ensure temporary files are removed after processing.
1 parent dce6324 commit aa5d89f

File tree

2 files changed

+131
-16
lines changed

2 files changed

+131
-16
lines changed

app/routes/document_routes.py

Lines changed: 54 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,12 @@
3131
QueryMultipleBody,
3232
)
3333
from app.services.vector_store.async_pg_vector import AsyncPgVector
34-
from app.utils.document_loader import get_loader, clean_text, process_documents
34+
from app.utils.document_loader import (
35+
get_loader,
36+
clean_text,
37+
process_documents,
38+
cleanup_temp_encoding_file,
39+
)
3540
from app.utils.health import is_health_ok
3641

3742
router = APIRouter()
@@ -83,8 +88,12 @@ async def health_check():
8388
async def get_documents_by_ids(request: Request, ids: list[str] = Query(...)):
8489
try:
8590
if isinstance(vector_store, AsyncPgVector):
86-
existing_ids = await vector_store.get_filtered_ids(ids, executor=request.app.state.thread_pool)
87-
documents = await vector_store.get_documents_by_ids(ids, executor=request.app.state.thread_pool)
91+
existing_ids = await vector_store.get_filtered_ids(
92+
ids, executor=request.app.state.thread_pool
93+
)
94+
documents = await vector_store.get_documents_by_ids(
95+
ids, executor=request.app.state.thread_pool
96+
)
8897
else:
8998
existing_ids = vector_store.get_filtered_ids(ids)
9099
documents = vector_store.get_documents_by_ids(ids)
@@ -121,8 +130,12 @@ async def get_documents_by_ids(request: Request, ids: list[str] = Query(...)):
121130
async def delete_documents(request: Request, document_ids: List[str] = Body(...)):
122131
try:
123132
if isinstance(vector_store, AsyncPgVector):
124-
existing_ids = await vector_store.get_filtered_ids(document_ids, executor=request.app.state.thread_pool)
125-
await vector_store.delete(ids=document_ids, executor=request.app.state.thread_pool)
133+
existing_ids = await vector_store.get_filtered_ids(
134+
document_ids, executor=request.app.state.thread_pool
135+
)
136+
await vector_store.delete(
137+
ids=document_ids, executor=request.app.state.thread_pool
138+
)
126139
else:
127140
existing_ids = vector_store.get_filtered_ids(document_ids)
128141
vector_store.delete(ids=document_ids)
@@ -179,7 +192,7 @@ async def query_embeddings_by_file_id(
179192
embedding,
180193
k=body.k,
181194
filter={"file_id": body.file_id},
182-
executor=request.app.state.thread_pool
195+
executor=request.app.state.thread_pool,
183196
)
184197
else:
185198
documents = vector_store.similarity_search_with_score_by_vector(
@@ -245,7 +258,7 @@ async def store_data_in_vector_db(
245258
file_id: str,
246259
user_id: str = "",
247260
clean_content: bool = False,
248-
executor = None,
261+
executor=None,
249262
) -> bool:
250263
text_splitter = RecursiveCharacterTextSplitter(
251264
chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP
@@ -313,8 +326,16 @@ async def embed_local_file(
313326
document.filename, document.file_content_type, document.filepath
314327
)
315328
data = await run_in_executor(request.app.state.thread_pool, loader.load)
329+
330+
# Clean up temporary UTF-8 file if it was created for encoding conversion
331+
cleanup_temp_encoding_file(loader)
332+
316333
result = await store_data_in_vector_db(
317-
data, document.file_id, user_id, clean_content=file_ext == "pdf", executor=request.app.state.thread_pool
334+
data,
335+
document.file_id,
336+
user_id,
337+
clean_content=file_ext == "pdf",
338+
executor=request.app.state.thread_pool,
318339
)
319340

320341
if result:
@@ -391,8 +412,16 @@ async def embed_file(
391412
file.filename, file.content_type, temp_file_path
392413
)
393414
data = await run_in_executor(request.app.state.thread_pool, loader.load)
415+
416+
# Clean up temporary UTF-8 file if it was created for encoding conversion
417+
cleanup_temp_encoding_file(loader)
418+
394419
result = await store_data_in_vector_db(
395-
data=data, file_id=file_id, user_id=user_id, clean_content=file_ext == "pdf", executor=request.app.state.thread_pool
420+
data=data,
421+
file_id=file_id,
422+
user_id=user_id,
423+
clean_content=file_ext == "pdf",
424+
executor=request.app.state.thread_pool,
396425
)
397426

398427
if not result:
@@ -458,8 +487,12 @@ async def load_document_context(request: Request, id: str):
458487
ids = [id]
459488
try:
460489
if isinstance(vector_store, AsyncPgVector):
461-
existing_ids = await vector_store.get_filtered_ids(ids, executor=request.app.state.thread_pool)
462-
documents = await vector_store.get_documents_by_ids(ids, executor=request.app.state.thread_pool)
490+
existing_ids = await vector_store.get_filtered_ids(
491+
ids, executor=request.app.state.thread_pool
492+
)
493+
documents = await vector_store.get_documents_by_ids(
494+
ids, executor=request.app.state.thread_pool
495+
)
463496
else:
464497
existing_ids = vector_store.get_filtered_ids(ids)
465498
documents = vector_store.get_documents_by_ids(ids)
@@ -526,8 +559,16 @@ async def embed_file_upload(
526559
)
527560

528561
data = await run_in_executor(request.app.state.thread_pool, loader.load)
562+
563+
# Clean up temporary UTF-8 file if it was created for encoding conversion
564+
cleanup_temp_encoding_file(loader)
565+
529566
result = await store_data_in_vector_db(
530-
data, file_id, user_id, clean_content=file_ext == "pdf", executor=request.app.state.thread_pool
567+
data,
568+
file_id,
569+
user_id,
570+
clean_content=file_ext == "pdf",
571+
executor=request.app.state.thread_pool,
531572
)
532573

533574
if not result:
@@ -577,7 +618,7 @@ async def query_embeddings_by_file_ids(request: Request, body: QueryMultipleBody
577618
embedding,
578619
k=body.k,
579620
filter={"file_id": {"$in": body.file_ids}},
580-
executor=request.app.state.thread_pool
621+
executor=request.app.state.thread_pool,
581622
)
582623
else:
583624
documents = vector_store.similarity_search_with_score_by_vector(

app/utils/document_loader.py

Lines changed: 77 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
# app/utils/document_loader.py
2+
import os
3+
import codecs
4+
import tempfile
25
from typing import List, Optional
36

47
from langchain_core.documents import Document
58

6-
from app.config import known_source_ext, PDF_EXTRACT_IMAGES, CHUNK_OVERLAP
9+
from app.config import known_source_ext, PDF_EXTRACT_IMAGES, CHUNK_OVERLAP, logger
710
from langchain_community.document_loaders import (
811
TextLoader,
912
PyPDFLoader,
@@ -17,14 +20,83 @@
1720
UnstructuredPowerPointLoader,
1821
)
1922

23+
24+
def detect_file_encoding(filepath: str) -> str:
    """
    Detect the encoding of a file by checking for BOM markers.

    BOMs are tested longest-first: the UTF-32-LE BOM (FF FE 00 00) begins
    with the UTF-16-LE BOM (FF FE), so testing the 2-byte marker first
    would misdetect every UTF-32-LE file as UTF-16-LE.

    :param filepath: Path of the file to inspect.
    :return: Codec name suitable for ``open(..., encoding=...)``;
        ``'utf-8'`` when no BOM is present.
    """
    with open(filepath, "rb") as f:
        raw = f.read(4)

    # Order matters: 4-byte UTF-32 BOMs before 2-byte UTF-16 BOMs.
    if raw.startswith(codecs.BOM_UTF32_LE):
        return "utf-32-le"
    elif raw.startswith(codecs.BOM_UTF32_BE):
        return "utf-32-be"
    elif raw.startswith(codecs.BOM_UTF16_LE):
        return "utf-16-le"
    elif raw.startswith(codecs.BOM_UTF16_BE):
        return "utf-16-be"
    elif raw.startswith(codecs.BOM_UTF8):
        return "utf-8-sig"
    else:
        # No BOM found — default to UTF-8 (a superset of ASCII).
        return "utf-8"
48+
49+
50+
def cleanup_temp_encoding_file(loader) -> None:
    """
    Remove the temporary UTF-8 file a loader may have created for
    encoding conversion (see the CSV handling in ``get_loader``).

    Idempotent: the ``_temp_filepath`` marker is cleared after removal,
    so repeated calls are no-ops instead of logging spurious warnings.
    An already-missing file also counts as success. Other failures are
    logged and swallowed — cleanup is best-effort and must never break
    the request that triggered it.

    :param loader: Document loader that may carry a ``_temp_filepath``
        attribute pointing at the temporary file.
    """
    temp_path = getattr(loader, "_temp_filepath", None)
    if temp_path is None:
        return
    try:
        os.remove(temp_path)
    except FileNotFoundError:
        # Already gone — treat as successful cleanup.
        pass
    except Exception as e:
        logger.warning(f"Failed to remove temporary UTF-8 file: {e}")
        return
    # Drop the marker so a second cleanup call is a silent no-op.
    del loader._temp_filepath
61+
62+
2063
def get_loader(filename: str, file_content_type: str, filepath: str):
2164
file_ext = filename.split(".")[-1].lower()
2265
known_type = True
2366

2467
if file_ext == "pdf":
2568
loader = PyPDFLoader(filepath, extract_images=PDF_EXTRACT_IMAGES)
2669
elif file_ext == "csv":
27-
loader = CSVLoader(filepath)
70+
# Detect encoding for CSV files
71+
encoding = detect_file_encoding(filepath)
72+
73+
if encoding != "utf-8":
74+
# For non-UTF-8 encodings, we need to convert the file first
75+
# Create a temporary UTF-8 file
76+
temp_file = None
77+
try:
78+
with tempfile.NamedTemporaryFile(
79+
mode="w", encoding="utf-8", suffix=".csv", delete=False
80+
) as temp_file:
81+
# Read the original file with detected encoding
82+
with open(filepath, "r", encoding=encoding) as original_file:
83+
content = original_file.read()
84+
temp_file.write(content)
85+
86+
temp_filepath = temp_file.name
87+
88+
# Use the temporary UTF-8 file with CSVLoader
89+
loader = CSVLoader(temp_filepath)
90+
91+
# Store the temp file path for cleanup
92+
loader._temp_filepath = temp_filepath
93+
except Exception as e:
94+
# If temp file was created but there was an error, clean it up
95+
if temp_file and os.path.exists(temp_file.name):
96+
os.unlink(temp_file.name)
97+
raise e
98+
else:
99+
loader = CSVLoader(filepath)
28100
elif file_ext == "rst":
29101
loader = UnstructuredRSTLoader(filepath, mode="elements")
30102
elif file_ext == "xml":
@@ -58,6 +130,7 @@ def get_loader(filename: str, file_content_type: str, filepath: str):
58130

59131
return loader, known_type, file_ext
60132

133+
61134
def clean_text(text: str) -> str:
    """
    Return *text* with every NUL (0x00) character removed.

    :param text: Input string, possibly containing NUL characters.
    :return: The input string minus all NUL characters.
    """
    # str.translate deletes code points mapped to None in one C-level pass.
    return text.translate({0: None})
69142

143+
70144
def process_documents(documents: List[Document]) -> str:
71145
processed_text = ""
72146
last_page: Optional[int] = None
@@ -91,4 +165,4 @@ def process_documents(documents: List[Document]) -> str:
91165
else:
92166
processed_text += new_content
93167

94-
return processed_text.strip()
168+
return processed_text.strip()

0 commit comments

Comments
 (0)