Skip to content

Commit 37d5c38

Browse files
committed
fixes in tf-idf dendrogram generation
1 parent 42df8c7 commit 37d5c38

File tree

1 file changed

+16
-12
lines changed

1 file changed

+16
-12
lines changed

backend/Affinity_strategy.py

Lines changed: 16 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -103,9 +103,8 @@ def __init__(self, verb_weight=1.0, object_weight=1.0):
103103
self.object_weight = object_weight
104104

105105
def get_dense_data_array(self, data: List) -> np.ndarray:
106-
tfidf_vectorizer = TfidfVectorizer()
107-
tf_idf_data_vector = tfidf_vectorizer.fit_transform(data)
108-
return tf_idf_data_vector.toarray(), tfidf_vectorizer
106+
tf_idf_data_vector = self.vectorizer.fit_transform(data)
107+
return tf_idf_data_vector.toarray()
109108

110109
def compute_affinity(self,
111110
application_name,
@@ -120,7 +119,7 @@ def compute_affinity(self,
120119
self.object_weight = object_weight
121120

122121
print("Converting data to dense TF-IDF vectors...")
123-
dense_data_array, tfidf_vectorizer = self.get_dense_data_array(labels)
122+
dense_data_array = self.get_dense_data_array(labels)
124123

125124
zero_vectors = np.all(dense_data_array == 0, axis=1)
126125
print(f"Number of zero vectors: {np.sum(zero_vectors)}")
@@ -134,20 +133,25 @@ def compute_affinity(self,
134133
return None
135134

136135
print("Ponderating TF-IDF embeddings with verb and object weights...")
137-
# Adjust TF-IDF values based on verb and object weights
138136
dense_data_array = Utils.ponderate_tfidf_with_weights(
139-
labels, # batch_data
140-
dense_data_array, # tfidf_matrix
141-
tfidf_vectorizer, # vectorizer
137+
labels,
138+
dense_data_array,
139+
self.vectorizer,
142140
verb_weight=self.verb_weight,
143141
object_weight=self.object_weight
144142
)
145143

146144
print("Performing Agglomerative Clustering...")
147-
clustering_model = AgglomerativeClustering(n_clusters=None,
148-
linkage=linkage,
149-
distance_threshold=distance_threshold,
150-
metric=metric)
145+
if linkage == "ward":
146+
clustering_model = AgglomerativeClustering(n_clusters=None,
147+
linkage=linkage,
148+
distance_threshold=distance_threshold)
149+
else:
150+
clustering_model = AgglomerativeClustering(n_clusters=None,
151+
linkage=linkage,
152+
distance_threshold=distance_threshold,
153+
metric=metric)
154+
151155
clustering_model.fit(dense_data_array)
152156

153157
return Utils.generate_pkl(application_name,

0 commit comments

Comments
 (0)