5
5
import numpy as np
6
6
from cache import AsyncTTL
7
7
from tqdm import tqdm
8
- import os
8
+ import os
9
9
10
10
11
11
class Model:
    """Holds the tag corpus for one search category plus its IDF weights."""

    def __init__(self, seed_df, pesticide_df, fertilizer_df, global_df, request: ModelRequest, search_categoty='others'):
        """Pick the dataframe matching ``request.search_category`` and precompute IDF.

        Args:
            seed_df: tag corpus used when the category is 'seed'.
            pesticide_df: corpus for the 'pesticide' category.
            fertilizer_df: corpus for the 'fertilizer' category.
            global_df: fallback corpus for any other category.
            request: incoming request; only its ``search_category`` is read here.
            search_categoty: NOTE(review) — parameter name is misspelled and the
                value is never read (the category comes from the request object).
                Kept only so existing call sites remain valid; confirm before removing.
        """
        self.search_category = request.search_category
        # Dispatch table instead of an if/elif chain; unknown categories
        # fall back to the global corpus.
        corpus_by_category = {
            'seed': seed_df,
            'fertilizer': fertilizer_df,
            'pesticide': pesticide_df,
        }
        self.df = corpus_by_category.get(self.search_category, global_df)
        self.idf_dict = self.__compute_idf(self.df)
20
23
21
24
@staticmethod
22
25
def __compute_idf (df ):
23
26
N = len (df )
24
- all_tags = df ['tags' ].str .lower (). str . split ().explode ()
27
+ all_tags = df ['tags' ].str .split ().explode ()
25
28
df_count_series = all_tags .drop_duplicates ().value_counts ()
26
29
idf_dict = {tag : log (N / (df_count + 1 )) for tag , df_count in df_count_series .items ()}
27
30
return idf_dict
28
-
31
+
29
32
def __fuzzy_match (self , query_tokens , doc_tokens ):
30
33
weighted_fuzzy_scores = []
31
34
query_set = set (query_tokens )
@@ -40,7 +43,8 @@ def __fuzzy_match(self, query_tokens, doc_tokens):
40
43
max_ratio = ratio
41
44
max_token = token
42
45
43
- idf_weight = self .idf_dict .get (max_token )
46
+
47
+ idf_weight = self .idf_dict .get (max_token , 0.0 )
44
48
weighted_fuzzy_scores .append ((max_ratio / 100 ) * idf_weight )
45
49
46
50
return np .mean (weighted_fuzzy_scores )
@@ -50,11 +54,11 @@ def __fuzzy_match(self, query_tokens, doc_tokens):
50
54
async def inference (self , request : ModelRequest ):
51
55
scores = []
52
56
query = request .query
53
- n = request .n
57
+ n = int ( request .n )
54
58
query_tokens = query .lower ().split ()
55
59
56
60
for _ , row in tqdm (self .df .iterrows ()):
57
- doc_tokens = row ['tags' ]. lower ( ).split ()
61
+ doc_tokens = str ( row ['tags' ]).split ()
58
62
fuzzy_score = self .__fuzzy_match (query_tokens , doc_tokens )
59
63
scores .append (fuzzy_score )
60
64
0 commit comments