Skip to content

Commit 6547565

Browse files
committed
implement random search run mode
1 parent 21cc7db commit 6547565

File tree

1 file changed

+60
-49
lines changed

1 file changed

+60
-49
lines changed

main.py

Lines changed: 60 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from bcorag import option_picker as op
66
from bcorag.bcorag import BcoRag
77
from parameter_search.grid_search import BcoGridSearch
8+
from parameter_search.random_search import BcoRandomSearch
89
from bcorag.custom_types import (
910
GitFilter,
1011
GitFilters,
@@ -29,6 +30,56 @@ def main() -> None:
2930
options = parser.parse_args()
3031
run_mode = options.run_mode.lower().strip()
3132

33+
filenames = ["./bcorag/test_papers/High resolution measurement.pdf"]
34+
loaders = "SimpleDirectoryReader"
35+
chunking_config = [
36+
"1024 chunk size/20 chunk overlap",
37+
"2048 chunk size/50 chunk overlap",
38+
]
39+
embedding_model = "text-embedding-3-large"
40+
vector_store = "VectorStoreIndex"
41+
similarity_top_k = [2, 3, 4]
42+
llms = ["gpt-3.5-turbo", "gpt-4-turbo"]
43+
44+
github_url = "https://github.com/dpastling/plethora"
45+
git_info = misc_fns.extract_repo_data(github_url)
46+
if git_info is None:
47+
misc_fns.graceful_exit(1, "Error parsing github URL.")
48+
49+
git_filters: list[GitFilters] = []
50+
directory_filter = create_git_filters(
51+
filter_type=GithubRepositoryReader.FilterType.EXCLUDE,
52+
filter=GitFilter.DIRECTORY,
53+
value=["logs", "fastq", "data"],
54+
)
55+
git_filters.append(directory_filter)
56+
57+
file_ext_filter = create_git_filters(
58+
filter_type=GithubRepositoryReader.FilterType.EXCLUDE,
59+
filter=GitFilter.FILE_EXTENSION,
60+
value=["txt", "gz", "bed"],
61+
)
62+
git_filters.append(file_ext_filter)
63+
64+
git_data = create_git_data(
65+
user=git_info[0], repo=git_info[1], branch="master", filters=git_filters
66+
)
67+
68+
git_file_data = create_git_data_file_config(
69+
os.path.basename(filenames[0]), git_data
70+
)
71+
72+
search_space = init_search_space(
73+
filenames=filenames,
74+
loader=loaders,
75+
chunking_config=chunking_config,
76+
embedding_model=embedding_model,
77+
vector_store=vector_store,
78+
similarity_top_k=similarity_top_k,
79+
llm=llms,
80+
git_data=[git_file_data],
81+
)
82+
3283
match run_mode:
3384

3485
case "one-shot":
@@ -57,63 +108,23 @@ def main() -> None:
57108
"################################## RUN START ##################################"
58109
)
59110

60-
filenames = ["./bcorag/test_papers/High resolution measurement.pdf"]
61-
loaders = "SimpleDirectoryReader"
62-
chunking_config = [
63-
"1024 chunk size/20 chunk overlap",
64-
"2048 chunk size/50 chunk overlap",
65-
]
66-
embedding_model = "text-embedding-3-large"
67-
vector_store = "VectorStoreIndex"
68-
similarity_top_k = [2, 3, 4]
69-
llms = ["gpt-3.5-turbo", "gpt-4-turbo"]
70-
71-
github_url = "https://github.com/dpastling/plethora"
72-
git_info = misc_fns.extract_repo_data(github_url)
73-
if git_info is None:
74-
misc_fns.graceful_exit(1, "Error parsing github URL.")
75-
76-
git_filters: list[GitFilters] = []
77-
directory_filter = create_git_filters(
78-
filter_type=GithubRepositoryReader.FilterType.EXCLUDE,
79-
filter=GitFilter.DIRECTORY,
80-
value=["logs", "fastq", "data"],
81-
)
82-
git_filters.append(directory_filter)
83-
84-
file_ext_filter = create_git_filters(
85-
filter_type=GithubRepositoryReader.FilterType.EXCLUDE,
86-
filter=GitFilter.FILE_EXTENSION,
87-
value=["txt", "gz", "bed"],
88-
)
89-
git_filters.append(file_ext_filter)
111+
grid_search = BcoGridSearch(search_space)
112+
grid_search.train()
90113

91-
git_data = create_git_data(user=git_info[0], repo=git_info[1], branch="master", filters=git_filters)
114+
misc_fns.graceful_exit()
92115

93-
git_file_data = create_git_data_file_config(
94-
os.path.basename(filenames[0]), git_data
95-
)
116+
case "random-search":
96117

97-
search_space = init_search_space(
98-
filenames=filenames,
99-
loader=loaders,
100-
chunking_config=chunking_config,
101-
embedding_model=embedding_model,
102-
vector_store=vector_store,
103-
similarity_top_k=similarity_top_k,
104-
llm=llms,
105-
git_data=[git_file_data],
118+
logger = misc_fns.setup_root_logger("./logs/random-search.log")
119+
logger.info(
120+
"################################## RUN START ##################################"
106121
)
107122

108-
grid_search = BcoGridSearch(search_space)
109-
grid_search.train()
123+
random_search = BcoRandomSearch(search_space, subset_size=5)
124+
random_search.train()
110125

111126
misc_fns.graceful_exit()
112127

113-
case "random-search":
114-
# TODO : implement
115-
pass
116-
117128
case _:
118129

119130
misc_fns.graceful_exit(1, "Unsupported run mode.")

0 commit comments

Comments
 (0)