
Commit 1010ef6

SirDegraf and JWittmeyer authored
Recommended downloaded models (#87)
* Adds logic to return user provided models too
* Updates submodule
* Language to n/a
* Adds check for existing
* Rename and restructure for performance

Co-authored-by: JWittmeyer <[email protected]>
1 parent 2b63dce commit 1010ef6

File tree

3 files changed: +68 -3 lines changed


controller/embedding/manager.py

Lines changed: 27 additions & 1 deletion
```diff
@@ -5,10 +5,36 @@
 from util import daemon
 from . import util
 from . import connector
+from controller.model_provider import manager as model_manager


 def get_recommended_encoders() -> List[Any]:
-    return connector.request_listing_recommended_encoders()
+    recommendations = connector.request_listing_recommended_encoders()
+    existing_models = model_manager.get_model_provider_info()
+    for model in existing_models:
+
+        if not model["zero_shot_pipeline"]:
+            not_yet_known = (
+                len(
+                    list(
+                        filter(
+                            lambda rec: rec["config_string"] == model["name"],
+                            recommendations,
+                        )
+                    )
+                )
+                == 0
+            )
+            if not_yet_known:
+                recommendations.append(
+                    {
+                        "config_string": model["name"],
+                        "description": "User downloaded model",
+                        "tokenizers": ["all"],
+                        "applicability": {"attribute": True, "token": True},
+                    }
+                )
+    return recommendations


 def create_attribute_level_embedding(
```
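The added duplicate check builds a list via `filter` just to count matches. Below is a minimal sketch of an equivalent formulation using `any()`; the dict keys (`config_string`, `zero_shot_pipeline`, `name`) follow the diff above, while the helper name `merge_user_encoders` is hypothetical and only for illustration, not part of the commit.

```python
# Sketch only: equivalent duplicate check with any(); data shapes follow the diff above.
from typing import Any, Dict, List


def merge_user_encoders(
    recommendations: List[Dict[str, Any]], existing_models: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
    for model in existing_models:
        # Only non-zero-shot models are treated as encoder recommendations.
        if model["zero_shot_pipeline"]:
            continue
        already_known = any(
            rec["config_string"] == model["name"] for rec in recommendations
        )
        if not already_known:
            recommendations.append(
                {
                    "config_string": model["name"],
                    "description": "User downloaded model",
                    "tokenizers": ["all"],
                    "applicability": {"attribute": True, "token": True},
                }
            )
    return recommendations
```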

controller/model_provider/__init__.py

Whitespace-only changes.

controller/zero_shot/manager.py

Lines changed: 41 additions & 2 deletions
```diff
@@ -22,6 +22,7 @@
 )
 from util import daemon
 from controller.weak_supervision import weak_supervision_service as weak_supervision
+from controller.model_provider import manager as model_manager


 def get_zero_shot_text(
```
```diff
@@ -46,15 +47,44 @@ def get_zero_shot_recommendations(
     project_id: Optional[str] = None,
 ) -> List[Dict[str, str]]:
     recommendations = zs_service.get_recommended_models()
+    existing_models = model_manager.get_model_provider_info()
+
+    for model in existing_models:
+        if model["zero_shot_pipeline"]:
+            not_existing_yet = (
+                len(
+                    list(
+                        filter(
+                            lambda rec: rec["configString"] == model["name"],
+                            recommendations,
+                        )
+                    )
+                )
+                == 0
+            )
+            if not_existing_yet:
+                recommendations.append(
+                    {
+                        "configString": model["name"],
+                        "avgTime": "n/a",
+                        "language": "n/a",
+                        "link": model["link"],
+                        "base": "n/a",
+                        "size": __format_size_string(model["size"]),
+                        "prio": 1,
+                    }
+                )
+
     if not project_id:
         return recommendations

     project_item = project.get(project_id)
     if project_item and project_item.tokenizer_blank:
         recommendations = [
-            r for r in recommendations if r["language"] == project_item.tokenizer_blank
+            r
+            for r in recommendations
+            if r["language"] == project_item.tokenizer_blank or r["language"] == "n/a"
         ]
-
     return recommendations

```
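Because the user-provided entries carry "language": "n/a", the reworked filter above keeps both language-matching recommendations and downloaded models. A minimal sketch of that behavior with made-up sample data (the model names and tokenizer value are illustrative only):

```python
# Sketch only: illustrates the language filter with hypothetical sample data.
recommendations = [
    {"configString": "english-model", "language": "en"},
    {"configString": "german-model", "language": "de"},
    {"configString": "my-downloaded-model", "language": "n/a"},  # user downloaded
]
tokenizer_blank = "en"  # assume the project uses an English tokenizer

filtered = [
    r
    for r in recommendations
    if r["language"] == tokenizer_blank or r["language"] == "n/a"
]
# keeps "english-model" and "my-downloaded-model"; drops "german-model"
```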

```diff
@@ -199,3 +229,12 @@ def cancel_zero_shot_run(
     # setting the state to failed with be noted by the thread in zs service and handled
     item.state = enums.PayloadState.FAILED.value
     general.commit()
+
+
+def __format_size_string(size: int) -> str:
+    size_in_mb = int(size / 1048576)
+
+    if size_in_mb < 1024:
+        return str(size_in_mb) + " MB"
+    else:
+        return str(size_in_mb / 1024) + " GB"
```
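Since the divisor 1048576 is 1024², the new helper reports sizes below 1 GiB as whole megabytes and anything larger as (possibly fractional) gigabytes. A standalone sketch with a renamed copy of the function and a few assumed byte counts:

```python
# Sketch only: standalone copy of the size formatter plus expected outputs
# for a few assumed input sizes (in bytes).
def format_size_string(size: int) -> str:
    size_in_mb = int(size / 1048576)  # 1048576 bytes = 1 MB (1024 ** 2)
    if size_in_mb < 1024:
        return str(size_in_mb) + " MB"
    else:
        return str(size_in_mb / 1024) + " GB"


assert format_size_string(734003200) == "700 MB"    # exactly 700 MB
assert format_size_string(1610612736) == "1.5 GB"   # 1536 MB -> 1.5 GB
assert format_size_string(2147483648) == "2.0 GB"   # float division keeps the ".0"
```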
