Skip to content

Commit 31fa6e8

Browse files
committed
fix typing errors
1 parent d2c095c commit 31fa6e8

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

src/autointent/_wrappers/embedder/hashing_vectorizer.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def get_hash(self) -> int:
6767
hasher.update(self.config.norm if self.config.norm is not None else "None")
6868
hasher.update(self.config.binary)
6969
hasher.update(self.config.dtype)
70-
return hasher.hexdigest()
70+
return int(hasher.hexdigest(), 16)
7171

7272
@overload
7373
def embed(
@@ -97,7 +97,7 @@ def embed(
9797
"""
9898
# Transform texts to sparse matrix, then convert to dense
9999
embeddings_sparse = self._vectorizer.transform(utterances)
100-
embeddings = embeddings_sparse.toarray().astype(np.float32)
100+
embeddings: npt.NDArray[np.float32] = embeddings_sparse.toarray().astype(np.float32)
101101

102102
if return_tensors:
103103
return torch.from_numpy(embeddings)
@@ -115,7 +115,8 @@ def similarity(
115115
Returns:
116116
Similarity matrix with shape (n_samples, m_samples).
117117
"""
118-
return cosine_similarity(embeddings1, embeddings2).astype(np.float32)
118+
similarity_matrix: npt.NDArray[np.float32] = cosine_similarity(embeddings1, embeddings2).astype(np.float32)
119+
return similarity_matrix
119120

120121
def dump(self, path: Path) -> None:
121122
"""Save the backend state to disk.
@@ -157,7 +158,7 @@ def load(cls, path: Path) -> "HashingVectorizerEmbeddingBackend":
157158
logger.debug("Loaded HashingVectorizer backend from %s", path)
158159
return instance
159160

160-
def train(self, utterances: list[str], labels: list[int], config) -> None: # noqa: ANN001
161+
def train(self, utterances: list[str], labels: list[int], config) -> None: # noqa: ANN001 # type: ignore[no-untyped-def]
161162
"""Train the backend.
162163
163164
HashingVectorizer is stateless and doesn't support training.

0 commit comments

Comments
 (0)