Skip to content

Commit 21cc7db

Browse files
committed
add random search
1 parent b885348 commit 21cc7db

File tree

1 file changed

+95
-0
lines changed

1 file changed

+95
-0
lines changed

parameter_search/random_search.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
"""Random search class.
2+
"""
3+
4+
from .parameter_search import BcoParameterSearch
5+
from .custom_types import SearchSpace
6+
from bcorag.custom_types import UserSelections, create_git_data, create_user_selections
7+
from itertools import product
8+
import os
9+
import random
10+
11+
12+
class BcoRandomSearch(BcoParameterSearch):
13+
"""BCO random search class. Subclass of
14+
BcoParameterSearch.
15+
"""
16+
17+
def __init__(self, search_space: SearchSpace, subset_size: int = 5):
18+
"""Constructor.
19+
20+
Parameters
21+
----------
22+
search_space : SearchSpace
23+
The parameter search space.
24+
subset_size : int (default: 5)
25+
The number of parameter sets to search.
26+
"""
27+
super().__init__(search_space)
28+
self.subset_size = subset_size
29+
30+
def _create_param_sets(self) -> list[UserSelections]:
31+
"""Creates a random subset of the parameter space."""
32+
param_sets: list[UserSelections] = []
33+
34+
for (
35+
llm,
36+
embedding_model,
37+
filepath,
38+
loader,
39+
chunking_config,
40+
vector_store,
41+
similarity_top_k,
42+
) in product(
43+
self._llms,
44+
self._embedding_models,
45+
self._files,
46+
self._loaders,
47+
self._chunking_configs,
48+
self._vector_stores,
49+
self._similarity_top_k,
50+
):
51+
base_selections = {
52+
"llm": llm,
53+
"embedding_model": embedding_model,
54+
"filename": os.path.basename(str(filepath)),
55+
"filepath": filepath,
56+
"vector_store": vector_store,
57+
"loader": loader,
58+
"mode": "production",
59+
"similarity_top_k": similarity_top_k,
60+
"chunking_config": chunking_config,
61+
}
62+
63+
if self._git_data is None:
64+
base_selections["git_data"] = None
65+
else:
66+
for git_data in self._git_data:
67+
if git_data["filename"] == filepath or git_data[
68+
"filename"
69+
] == os.path.basename(str(filepath)):
70+
base_selections["git_data"] = create_git_data(
71+
user=git_data["git_info"]["user"],
72+
repo=git_data["git_info"]["repo"],
73+
branch=git_data["git_info"]["branch"],
74+
filters=git_data["git_info"]["filters"],
75+
)
76+
user_selections = create_user_selections(
77+
base_selections["llm"],
78+
base_selections["embedding_model"],
79+
base_selections["filename"],
80+
base_selections["filepath"],
81+
base_selections["vector_store"],
82+
base_selections["loader"],
83+
base_selections["mode"],
84+
base_selections["similarity_top_k"],
85+
base_selections["chunking_config"],
86+
base_selections["git_data"],
87+
)
88+
param_sets.append(user_selections)
89+
90+
if self.subset_size > len(param_sets):
91+
self.subset_size = len(param_sets)
92+
93+
param_subset = random.sample(param_sets, self.subset_size)
94+
95+
return param_subset

0 commit comments

Comments
 (0)