Skip to content

Commit d3e4326

Browse files
committed
모듈 경로를 tools 와 llm direcotry로 구분 #134
1 parent a2a10fe commit d3e4326

File tree

9 files changed

+59
-60
lines changed

9 files changed

+59
-60
lines changed

llm_utils/chains.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
)
77
from pydantic import BaseModel, Field
88

9-
from .llm_factory import get_llm
9+
from llm_utils.llm import get_llm
1010

1111
from prompt.template_loader import get_prompt_template
1212

llm_utils/graph_utils/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from langgraph.graph.message import add_messages
77
from langchain.chains.sql_database.prompt import SQL_PROMPTS
88
from pydantic import BaseModel, Field
9-
from llm_utils.llm_factory import get_llm
9+
from llm_utils.llm import get_llm
1010

1111
from llm_utils.chains import (
1212
query_refiner_chain,

llm_utils/llm/__init__.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from .factory import (
2+
get_llm,
3+
get_llm_openai,
4+
get_llm_azure,
5+
get_llm_bedrock,
6+
get_llm_gemini,
7+
get_llm_ollama,
8+
get_llm_huggingface,
9+
get_embeddings,
10+
get_embeddings_openai,
11+
get_embeddings_azure,
12+
get_embeddings_bedrock,
13+
get_embeddings_gemini,
14+
get_embeddings_ollama,
15+
get_embeddings_huggingface,
16+
)
17+
18+
__all__ = [
19+
"get_llm",
20+
"get_llm_openai",
21+
"get_llm_azure",
22+
"get_llm_bedrock",
23+
"get_llm_gemini",
24+
"get_llm_ollama",
25+
"get_llm_huggingface",
26+
"get_embeddings",
27+
"get_embeddings_openai",
28+
"get_embeddings_azure",
29+
"get_embeddings_bedrock",
30+
"get_embeddings_gemini",
31+
"get_embeddings_ollama",
32+
"get_embeddings_huggingface",
33+
]
34+
35+
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# llm_factory.py
21
import os
32
from typing import Optional
43

@@ -180,3 +179,5 @@ def get_embeddings_huggingface() -> BaseLanguageModel:
180179
repo_id=os.getenv("HUGGING_FACE_EMBEDDING_REPO_ID"),
181180
huggingfacehub_api_token=os.getenv("HUGGING_FACE_EMBEDDING_API_TOKEN"),
182181
)
182+
183+

llm_utils/retrieval.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,7 @@
66
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
77
from transformers import AutoModelForSequenceClassification, AutoTokenizer
88

9-
from .tools import get_info_from_db
10-
from .llm_factory import get_embeddings
11-
from .vectordb import get_vector_db
9+
from llm_utils.vectordb import get_vector_db
1210

1311

1412
def load_reranker_model(device: str = "cpu"):

llm_utils/tools/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from .datahub import (
2+
set_gms_server,
3+
get_info_from_db,
4+
get_metadata_from_db,
5+
)
6+
7+
__all__ = [
8+
"set_gms_server",
9+
"get_info_from_db",
10+
"get_metadata_from_db",
11+
]
12+
13+
Lines changed: 4 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -18,18 +18,7 @@ def parallel_process(
1818
desc: Optional[str] = None,
1919
show_progress: bool = True,
2020
) -> List[R]:
21-
"""병렬 처리를 위한 유틸리티 함수
22-
23-
Args:
24-
items (Iterable[T]): 처리할 아이템들
25-
process_fn (Callable[[T], R]): 각 아이템을 처리할 함수
26-
max_workers (int, optional): 최대 쓰레드 수. Defaults to 8.
27-
desc (Optional[str], optional): 진행 상태 메시지. Defaults to None.
28-
show_progress (bool, optional): 진행 상태 표시 여부. Defaults to True.
29-
30-
Returns:
31-
List[R]: 처리 결과 리스트
32-
"""
21+
"""병렬 처리를 위한 유틸리티 함수"""
3322
with ThreadPoolExecutor(max_workers=max_workers) as executor:
3423
futures = [executor.submit(process_fn, item) for item in items]
3524
if show_progress:
@@ -67,14 +56,6 @@ def _process_column_info(
6756

6857

6958
def _get_table_info(max_workers: int = 8) -> Dict[str, str]:
70-
"""전체 테이블 이름과 설명을 가져오는 함수
71-
72-
Args:
73-
max_workers (int, optional): 병렬 처리에 사용할 최대 쓰레드 수. Defaults to 8.
74-
75-
Returns:
76-
Dict[str, str]: 테이블 이름과 설명을 담은 딕셔너리
77-
"""
7859
fetcher = _get_fetcher()
7960
urns = fetcher.get_urns()
8061
table_info = {}
@@ -96,40 +77,19 @@ def _get_table_info(max_workers: int = 8) -> Dict[str, str]:
9677
def _get_column_info(
9778
table_name: str, urn_table_mapping: Dict[str, str], max_workers: int = 8
9879
) -> List[Dict[str, str]]:
99-
"""table_name에 해당하는 컬럼 이름과 설명을 가져오는 함수
100-
101-
Args:
102-
table_name (str): 테이블 이름
103-
urn_table_mapping (Dict[str, str]): URN-테이블명 매핑 딕셔너리
104-
max_workers (int, optional): 병렬 처리에 사용할 최대 쓰레드 수. Defaults to 8.
105-
106-
Returns:
107-
List[Dict[str, str]]: 컬럼 정보 리스트
108-
"""
109-
# 해당 테이블의 URN 직접 찾기
11080
target_urn = urn_table_mapping.get(table_name)
11181
if not target_urn:
11282
return []
11383

114-
# Fetcher 생성 및 컬럼 정보 가져오기
11584
fetcher = _get_fetcher()
11685
column_info = fetcher.get_column_names_and_descriptions(target_urn)
11786

11887
return column_info
11988

12089

12190
def get_info_from_db(max_workers: int = 8) -> List[Document]:
122-
"""전체 테이블 이름과 설명, 컬럼 이름과 설명을 가져오는 함수
123-
124-
Args:
125-
max_workers (int, optional): 병렬 처리에 사용할 최대 쓰레드 수. Defaults to 8.
126-
127-
Returns:
128-
List[Document]: 테이블과 컬럼 정보를 담은 Document 객체 리스트
129-
"""
13091
table_info = _get_table_info(max_workers=max_workers)
13192

132-
# URN-테이블명 매핑을 한 번만 생성
13393
fetcher = _get_fetcher()
13494
urns = list(fetcher.get_urns())
13595
urn_table_mapping = {}
@@ -142,10 +102,8 @@ def process_table_info(item: tuple[str, str]) -> str:
142102
table_name, table_description = item
143103
urn = urn_table_mapping.get(table_name, "")
144104

145-
# fetcher 인스턴스 생성
146105
local_fetcher = _get_fetcher()
147106

148-
# 컬럼 정보 가져오기
149107
column_info = _get_column_info(
150108
table_name, urn_table_mapping, max_workers=max_workers
151109
)
@@ -156,13 +114,11 @@ def process_table_info(item: tuple[str, str]) -> str:
156114
]
157115
)
158116

159-
# 쿼리 및 용어집 정보 가져오기
160117
queries_result = local_fetcher.get_queries_by_urn(urn) if urn else {}
161118
glossary_terms_result = (
162119
local_fetcher.get_glossary_terms_by_urn(urn) if urn else {}
163120
)
164121

165-
# GraphQL 응답에서 실제 쿼리 리스트 추출
166122
queries = []
167123
if (
168124
queries_result
@@ -172,7 +128,6 @@ def process_table_info(item: tuple[str, str]) -> str:
172128
):
173129
queries = queries_result["data"]["listQueries"]["queries"]
174130

175-
# GraphQL 응답에서 실제 glossary terms 추출
176131
glossary_terms = []
177132
if (
178133
glossary_terms_result
@@ -199,10 +154,9 @@ def process_table_info(item: tuple[str, str]) -> str:
199154
}
200155
)
201156

202-
# 쿼리 정보를 name, description, statement.value만 추출하여 포맷
203157
if queries:
204158
formatted_queries = []
205-
for q in queries[:3]: # 최대 3개 쿼리만
159+
for q in queries[:3]:
206160
if isinstance(q, dict) and "properties" in q:
207161
props = q["properties"]
208162
name = props.get("name", "No name")
@@ -241,10 +195,6 @@ def process_table_info(item: tuple[str, str]) -> str:
241195

242196

243197
def get_metadata_from_db() -> List[Dict]:
244-
"""
245-
전체 테이블의 메타데이터(테이블 이름, 설명, 컬럼 이름, 설명, 테이블 lineage, 컬럼 별 lineage)를 가져오는 함수
246-
"""
247-
248198
fetcher = _get_fetcher()
249199
urns = list(fetcher.get_urns())
250200

@@ -256,3 +206,5 @@ def get_metadata_from_db() -> List[Dict]:
256206
metadata.append(table_metadata)
257207

258208
return metadata
209+
210+

llm_utils/vectordb/faiss_db.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from typing import Optional
88

99
from llm_utils.tools import get_info_from_db
10-
from llm_utils.llm_factory import get_embeddings
10+
from llm_utils.llm import get_embeddings
1111

1212

1313
def get_faiss_vector_db(vectordb_path: Optional[str] = None):

llm_utils/vectordb/pgvector_db.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from langchain_postgres.vectorstores import PGVector
1010

1111
from llm_utils.tools import get_info_from_db
12-
from llm_utils.llm_factory import get_embeddings
12+
from llm_utils.llm import get_embeddings
1313

1414

1515
def _check_collection_exists(connection_string: str, collection_name: str) -> bool:

0 commit comments

Comments
 (0)