Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/embedders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ def load_pca_weights(self, file_name: str):
def _reduce(
self,
documents: List[Union[str, Doc]],
+ as_generator: bool,
fit_model: bool,
fit_after_n_batches: int,
):
Expand All @@ -178,11 +179,11 @@ def _reduce_batch(
fit_after_n_batches: int,
) -> Union[List, Generator]:
if as_generator:
- return self._reduce(documents, fit_model, fit_after_n_batches)
+ return self._reduce(documents, as_generator, fit_model, fit_after_n_batches)
else:
embeddings = []
for embedding_batch in self._reduce(
- documents, fit_model, fit_after_n_batches
+ documents, as_generator, fit_model, fit_after_n_batches
):
embeddings.extend(embedding_batch)
return embeddings
Expand Down
10 changes: 7 additions & 3 deletions src/embedders/classification/reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def _transform(
def _reduce(
self,
documents: List[Union[str, Doc]],
+ as_generator: bool,
fit_model: bool,
fit_after_n_batches: int,
) -> Generator[List[List[Union[float, int]]], None, None]:
Expand All @@ -29,7 +30,7 @@ def _reduce(
num_batches = util.num_batches(documents, self.embedder.batch_size)
fit_after_n_batches = min(num_batches, fit_after_n_batches) - 1
for batch_idx, batch in enumerate(
- self.embedder.fit_transform(documents, as_generator=True)
+ self.embedder.fit_transform(documents, as_generator)
):
if batch_idx <= fit_after_n_batches:
embeddings_training.append(batch)
Expand All @@ -56,8 +57,11 @@ def _reduce(
if batch_idx > fit_after_n_batches:
yield self._transform(batch)
else:
- embeddings = self.embedder.transform(documents)
- yield self._transform(embeddings)
+ embeddings = self.embedder.transform(documents, as_generator)
+ if as_generator:
+ yield self._transform(list(embeddings))
+ else:
+ yield self._transform(embeddings)

@staticmethod
def load(embedder: dict) -> "PCASentenceReducer":
Expand Down
16 changes: 12 additions & 4 deletions src/embedders/extraction/reduce.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
+ from spacy.tokens.doc import Doc
from typing import List, Generator, Union
import numpy as np
from src.embedders import PCAReducer, util
Expand All @@ -24,14 +25,18 @@ def _transform(
return batch_unsqueezed

def _reduce(
- self, documents, fit_model, fit_after_n_batches
+ self,
+ documents: List[Union[str, Doc]],
+ as_generator: bool,
+ fit_model: bool,
+ fit_after_n_batches: int,
) -> Generator[List[List[List[Union[float, int]]]], None, None]:
if fit_model:
embeddings_training = []
num_batches = util.num_batches(documents, self.embedder.batch_size)
fit_after_n_batches = min(num_batches, fit_after_n_batches) - 1
for batch_idx, batch in enumerate(
- self.embedder.fit_transform(documents, as_generator=True)
+ self.embedder.fit_transform(documents, as_generator)
):
if batch_idx <= fit_after_n_batches:
embeddings_training.append(batch)
Expand Down Expand Up @@ -60,5 +65,8 @@ def _reduce(
if batch_idx > fit_after_n_batches:
yield self._transform(batch)
else:
- embeddings = self.embedder.transform(documents)
- yield self._transform(embeddings)
+ embeddings = self.embedder.transform(documents, as_generator)
+ if as_generator:
+ yield self._transform(list(embeddings))
+ else:
+ yield self._transform(embeddings)