Skip to content

Commit 57d0e3d

Browse files
committed
feat: 데이터베이스 정보 수집 과정 중 병렬 처리 기능 추가
1 parent 201a6c7 commit 57d0e3d

File tree

1 file changed

+106
-22
lines changed

1 file changed

+106
-22
lines changed

llm_utils/tools.py

Lines changed: 106 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,40 @@
11
import os
2-
from typing import List, Dict
2+
from typing import List, Dict, Optional, TypeVar, Callable, Iterable, Any
33

44
from langchain.schema import Document
55

66
from data_utils.datahub_source import DatahubMetadataFetcher
7+
from tqdm import tqdm
8+
from concurrent.futures import ThreadPoolExecutor
9+
10+
T = TypeVar("T")
11+
R = TypeVar("R")
12+
13+
14+
def parallel_process[T, R](
15+
items: Iterable[T],
16+
process_fn: Callable[[T], R],
17+
max_workers: int = 8,
18+
desc: Optional[str] = None,
19+
show_progress: bool = True,
20+
) -> List[R]:
21+
"""병렬 처리를 위한 유틸리티 함수
22+
23+
Args:
24+
items (Iterable[T]): 처리할 아이템들
25+
process_fn (Callable[[T], R]): 각 아이템을 처리할 함수
26+
max_workers (int, optional): 최대 쓰레드 수. Defaults to 8.
27+
desc (Optional[str], optional): 진행 상태 메시지. Defaults to None.
28+
show_progress (bool, optional): 진행 상태 표시 여부. Defaults to True.
29+
30+
Returns:
31+
List[R]: 처리 결과 리스트
32+
"""
33+
with ThreadPoolExecutor(max_workers=max_workers) as executor:
34+
futures = [executor.submit(process_fn, item) for item in items]
35+
if show_progress:
36+
futures = tqdm(futures, desc=desc)
37+
return [future.result() for future in futures]
738

839

940
def set_gms_server(gms_server: str):
@@ -22,47 +53,100 @@ def _get_fetcher():
2253
return DatahubMetadataFetcher(gms_server=gms_server)
2354

2455

25-
def _get_table_info() -> Dict[str, str]:
26-
"""전체 테이블 이름과 설명을 가져오는 함수"""
56+
def _process_urn(urn: str, fetcher: DatahubMetadataFetcher) -> tuple[str, str]:
57+
table_name = fetcher.get_table_name(urn)
58+
table_description = fetcher.get_table_description(urn)
59+
return (table_name, table_description)
60+
61+
62+
def _process_column_info(
63+
urn: str, table_name: str, fetcher: DatahubMetadataFetcher
64+
) -> Optional[List[Dict[str, str]]]:
65+
if fetcher.get_table_name(urn) == table_name:
66+
return fetcher.get_column_names_and_descriptions(urn)
67+
return None
68+
69+
70+
def _get_table_info(max_workers: int = 8) -> Dict[str, str]:
71+
"""전체 테이블 이름과 설명을 가져오는 함수
72+
73+
Args:
74+
max_workers (int, optional): 병렬 처리에 사용할 최대 쓰레드 수. Defaults to 8.
75+
76+
Returns:
77+
Dict[str, str]: 테이블 이름과 설명을 담은 딕셔너리
78+
"""
2779
fetcher = _get_fetcher()
2880
urns = fetcher.get_urns()
2981
table_info = {}
30-
for urn in urns:
31-
table_name = fetcher.get_table_name(urn)
32-
table_description = fetcher.get_table_description(urn)
82+
83+
results = parallel_process(
84+
urns,
85+
lambda urn: _process_urn(urn, fetcher),
86+
max_workers=max_workers,
87+
desc="테이블 정보 수집 중",
88+
)
89+
90+
for table_name, table_description in results:
3391
if table_name and table_description:
3492
table_info[table_name] = table_description
93+
3594
return table_info
3695

3796

38-
def _get_column_info(table_name: str) -> List[Dict[str, str]]:
39-
"""table_name에 해당하는 컬럼 이름과 설명을 가져오는 함수"""
97+
def _get_column_info(table_name: str, max_workers: int = 8) -> List[Dict[str, str]]:
98+
"""table_name에 해당하는 컬럼 이름과 설명을 가져오는 함수
99+
100+
Args:
101+
table_name (str): 테이블 이름
102+
max_workers (int, optional): 병렬 처리에 사용할 최대 쓰레드 수. Defaults to 8.
103+
104+
Returns:
105+
List[Dict[str, str]]: 컬럼 정보 리스트
106+
"""
40107
fetcher = _get_fetcher()
41108
urns = fetcher.get_urns()
42-
for urn in urns:
43-
if fetcher.get_table_name(urn) == table_name:
44-
return fetcher.get_column_names_and_descriptions(urn)
109+
110+
results = parallel_process(
111+
urns,
112+
lambda urn: _process_column_info(urn, table_name, fetcher),
113+
max_workers=max_workers,
114+
show_progress=False,
115+
)
116+
117+
for result in results:
118+
if result:
119+
return result
45120
return []
46121

47122

48-
def get_info_from_db() -> List[Document]:
49-
"""
50-
전체 테이블 이름과 설명, 컬럼 이름과 설명을 가져오는 함수
123+
def get_info_from_db(max_workers: int = 8) -> List[Document]:
124+
"""전체 테이블 이름과 설명, 컬럼 이름과 설명을 가져오는 함수
125+
126+
Args:
127+
max_workers (int, optional): 병렬 처리에 사용할 최대 쓰레드 수. Defaults to 8.
128+
129+
Returns:
130+
List[Document]: 테이블과 컬럼 정보를 담은 Document 객체 리스트
51131
"""
132+
table_info = _get_table_info(max_workers=max_workers)
52133

53-
table_info_str_list = []
54-
table_info = _get_table_info()
55-
for table_name, table_description in table_info.items():
56-
column_info = _get_column_info(table_name)
134+
def process_table_info(item: tuple[str, str]) -> str:
135+
table_name, table_description = item
136+
column_info = _get_column_info(table_name, max_workers=max_workers)
57137
column_info_str = "\n".join(
58138
[
59139
f"{col['column_name']}: {col['column_description']}"
60140
for col in column_info
61141
]
62142
)
63-
table_info_str_list.append(
64-
f"{table_name}: {table_description}\nColumns:\n {column_info_str}"
65-
)
143+
return f"{table_name}: {table_description}\nColumns:\n {column_info_str}"
144+
145+
table_info_str_list = parallel_process(
146+
table_info.items(),
147+
process_table_info,
148+
max_workers=max_workers,
149+
desc="컬럼 정보 수집 중",
150+
)
66151

67-
# table_info_str_list를 Document 객체 리스트로 변환
68152
return [Document(page_content=info) for info in table_info_str_list]

0 commit comments

Comments
 (0)