@@ -103,9 +103,8 @@ def __init__(self, verb_weight=1.0, object_weight=1.0):
103103 self .object_weight = object_weight
104104
def get_dense_data_array(self, data: List) -> np.ndarray:
    """Fit the instance's vectorizer on ``data`` and return a dense matrix.

    The vectorizer held in ``self.vectorizer`` is fitted by this call, so
    code that needs the fitted vectorizer afterwards reads it from the
    instance instead of receiving it as a second return value.

    Args:
        data: Documents handed unchanged to
            ``self.vectorizer.fit_transform``.

    Returns:
        The ``fit_transform`` result converted with ``.toarray()``, one row
        per entry of ``data``.
    """
    tf_idf_data_vector = self.vectorizer.fit_transform(data)
    return tf_idf_data_vector.toarray()
109108
110109 def compute_affinity (self ,
111110 application_name ,
@@ -120,7 +119,7 @@ def compute_affinity(self,
120119 self .object_weight = object_weight
121120
122121 print ("Converting data to dense TF-IDF vectors..." )
123- dense_data_array , tfidf_vectorizer = self .get_dense_data_array (labels )
122+ dense_data_array = self .get_dense_data_array (labels )
124123
125124 zero_vectors = np .all (dense_data_array == 0 , axis = 1 )
126125 print (f"Number of zero vectors: { np .sum (zero_vectors )} " )
@@ -134,20 +133,25 @@ def compute_affinity(self,
134133 return None
135134
136135 print ("Ponderating TF-IDF embeddings with verb and object weights..." )
137- # Adjust TF-IDF values based on verb and object weights
138136 dense_data_array = Utils .ponderate_tfidf_with_weights (
139- labels , # batch_data
140- dense_data_array , # tfidf_matrix
141- tfidf_vectorizer , # vectorizer
137+ labels ,
138+ dense_data_array ,
139+ self . vectorizer ,
142140 verb_weight = self .verb_weight ,
143141 object_weight = self .object_weight
144142 )
145143
146144 print ("Performing Agglomerative Clustering..." )
147- clustering_model = AgglomerativeClustering (n_clusters = None ,
148- linkage = linkage ,
149- distance_threshold = distance_threshold ,
150- metric = metric )
145+ if linkage == "ward" :
146+ clustering_model = AgglomerativeClustering (n_clusters = None ,
147+ linkage = linkage ,
148+ distance_threshold = distance_threshold )
149+ else :
150+ clustering_model = AgglomerativeClustering (n_clusters = None ,
151+ linkage = linkage ,
152+ distance_threshold = distance_threshold ,
153+ metric = metric )
154+
151155 clustering_model .fit (dense_data_array )
152156
153157 return Utils .generate_pkl (application_name ,
0 commit comments