@@ -26,7 +26,7 @@ def get_vector_db():
2626 return db
2727
2828
29- def load_reranker_model ():
29+ def load_reranker_model (device : str = "cpu" ):
3030 """한국어 reranker 모델을 로드하거나 다운로드합니다."""
3131 local_model_path = os .path .join (os .getcwd (), "ko_reranker_local" )
3232
@@ -36,27 +36,31 @@ def load_reranker_model():
3636 else :
3737 print ("⬇️ ko-reranker 모델 다운로드 및 저장 중..." )
3838 model = AutoModelForSequenceClassification .from_pretrained (
39- "Dongjin-kr/ko-reranker" ,
39+ "Dongjin-kr/ko-reranker"
4040 )
4141 tokenizer = AutoTokenizer .from_pretrained ("Dongjin-kr/ko-reranker" )
4242 model .save_pretrained (local_model_path )
4343 tokenizer .save_pretrained (local_model_path )
4444
45- return HuggingFaceCrossEncoder (model_name = local_model_path )
45+ return HuggingFaceCrossEncoder (
46+ model_name = local_model_path ,
47+ model_kwargs = {"device" : device },
48+ )
4649
4750
48- def get_retriever (retriever_name : str = "기본" , top_n : int = 5 ):
51+ def get_retriever (retriever_name : str = "기본" , top_n : int = 5 , device : str = "cpu" ):
4952 """검색기 타입에 따라 적절한 검색기를 생성합니다.
5053
5154 Args:
5255 retriever_name: 사용할 검색기 이름 ("기본", "재순위", 등)
5356 top_n: 반환할 상위 결과 개수
5457 """
58+ print (device )
5559 retrievers = {
5660 "기본" : lambda : get_vector_db ().as_retriever (search_kwargs = {"k" : top_n }),
5761 "Reranker" : lambda : ContextualCompressionRetriever (
5862 base_compressor = CrossEncoderReranker (
59- model = load_reranker_model (), top_n = top_n
63+ model = load_reranker_model (device ), top_n = top_n
6064 ),
6165 base_retriever = get_vector_db ().as_retriever (search_kwargs = {"k" : top_n }),
6266 ),
@@ -71,13 +75,17 @@ def get_retriever(retriever_name: str = "기본", top_n: int = 5):
7175 return retrievers [retriever_name ]()
7276
7377
74- def search_tables (query : str , retriever_name : str = "기본" , top_n : int = 5 ):
78+ def search_tables (
79+ query : str , retriever_name : str = "기본" , top_n : int = 5 , device : str = "cpu"
80+ ):
7581 """쿼리에 맞는 테이블 정보를 검색합니다."""
7682 if retriever_name == "기본" :
7783 db = get_vector_db ()
7884 doc_res = db .similarity_search (query , k = top_n )
7985 else :
80- retriever = get_retriever (retriever_name = retriever_name , top_n = top_n )
86+ retriever = get_retriever (
87+ retriever_name = retriever_name , top_n = top_n , device = device
88+ )
8189 doc_res = retriever .invoke (query )
8290
8391 # 결과를 사전 형태로 변환
0 commit comments