Skip to content

Commit 9656e67

Browse files
committed
feat: device 선택가능하도록 업데이트
1 parent 0b4e3f6 commit 9656e67

File tree

3 files changed

+26
-7
lines changed

3 files changed

+26
-7
lines changed

interface/lang2sql.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ def execute_query(
4848
database_env: str,
4949
retriever_name: str = "기본",
5050
top_n: int = 5,
51+
device: str = "cpu",
5152
) -> dict:
5253
"""
5354
Lang2SQL 그래프를 실행하여 자연어 쿼리를 SQL 쿼리로 변환하고 결과를 반환합니다.
@@ -70,6 +71,7 @@ def execute_query(
7071
"best_practice_query": "",
7172
"retriever_name": retriever_name,
7273
"top_n": top_n,
74+
"device": device,
7375
}
7476
)
7577

@@ -129,6 +131,12 @@ def display_result(
129131
index=0,
130132
)
131133

134+
device = st.selectbox(
135+
"모델 실행 장치를 선택하세요:",
136+
options=["cpu", "cuda"],
137+
index=0,
138+
)
139+
132140
retriever_options = {
133141
"기본": "벡터 검색 (기본)",
134142
"Reranker": "Reranker 검색 (정확도 향상)",
@@ -160,5 +168,6 @@ def display_result(
160168
database_env=user_database_env,
161169
retriever_name=user_retriever,
162170
top_n=user_top_n,
171+
device=device,
163172
)
164173
display_result(res=result, database=db)

llm_utils/graph.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ class QueryMakerState(TypedDict):
3434
generated_query: str
3535
retriever_name: str
3636
top_n: int
37+
device: str
3738

3839

3940
# 노드 함수: QUERY_REFINER 노드
@@ -57,6 +58,7 @@ def get_table_info_node(state: QueryMakerState):
5758
query=state["messages"][0].content,
5859
retriever_name=state["retriever_name"],
5960
top_n=state["top_n"],
61+
device=state["device"],
6062
)
6163
state["searched_tables"] = documents_dict
6264

llm_utils/retrieval.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def get_vector_db():
2626
return db
2727

2828

29-
def load_reranker_model():
29+
def load_reranker_model(device: str = "cpu"):
3030
"""한국어 reranker 모델을 로드하거나 다운로드합니다."""
3131
local_model_path = os.path.join(os.getcwd(), "ko_reranker_local")
3232

@@ -36,27 +36,31 @@ def load_reranker_model():
3636
else:
3737
print("⬇️ ko-reranker 모델 다운로드 및 저장 중...")
3838
model = AutoModelForSequenceClassification.from_pretrained(
39-
"Dongjin-kr/ko-reranker",
39+
"Dongjin-kr/ko-reranker"
4040
)
4141
tokenizer = AutoTokenizer.from_pretrained("Dongjin-kr/ko-reranker")
4242
model.save_pretrained(local_model_path)
4343
tokenizer.save_pretrained(local_model_path)
4444

45-
return HuggingFaceCrossEncoder(model_name=local_model_path)
45+
return HuggingFaceCrossEncoder(
46+
model_name=local_model_path,
47+
model_kwargs={"device": device},
48+
)
4649

4750

48-
def get_retriever(retriever_name: str = "기본", top_n: int = 5):
51+
def get_retriever(retriever_name: str = "기본", top_n: int = 5, device: str = "cpu"):
4952
"""검색기 타입에 따라 적절한 검색기를 생성합니다.
5053
5154
Args:
5255
retriever_name: 사용할 검색기 이름 ("기본", "재순위", 등)
5356
top_n: 반환할 상위 결과 개수
5457
"""
58+
print(device)
5559
retrievers = {
5660
"기본": lambda: get_vector_db().as_retriever(search_kwargs={"k": top_n}),
5761
"Reranker": lambda: ContextualCompressionRetriever(
5862
base_compressor=CrossEncoderReranker(
59-
model=load_reranker_model(), top_n=top_n
63+
model=load_reranker_model(device), top_n=top_n
6064
),
6165
base_retriever=get_vector_db().as_retriever(search_kwargs={"k": top_n}),
6266
),
@@ -71,13 +75,17 @@ def get_retriever(retriever_name: str = "기본", top_n: int = 5):
7175
return retrievers[retriever_name]()
7276

7377

74-
def search_tables(query: str, retriever_name: str = "기본", top_n: int = 5):
78+
def search_tables(
79+
query: str, retriever_name: str = "기본", top_n: int = 5, device: str = "cpu"
80+
):
7581
"""쿼리에 맞는 테이블 정보를 검색합니다."""
7682
if retriever_name == "기본":
7783
db = get_vector_db()
7884
doc_res = db.similarity_search(query, k=top_n)
7985
else:
80-
retriever = get_retriever(retriever_name=retriever_name, top_n=top_n)
86+
retriever = get_retriever(
87+
retriever_name=retriever_name, top_n=top_n, device=device
88+
)
8189
doc_res = retriever.invoke(query)
8290

8391
# 결과를 사전 형태로 변환

0 commit comments

Comments
 (0)