Skip to content

Commit 369b734

Browse files
fix: StopIteration issue (#169)

* fix: StopIteration issue
* fix: PCAReducer._reduce generator
* fix: revert as_generator hardcode in reducer fit_transform
* fix: missing embeddings definition

1 parent ea6c692 · commit 369b734

File tree

3 files changed: +30 −7 lines changed

src/embedders/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
```diff
@@ -165,6 +165,7 @@ def load_pca_weights(self, file_name: str):
     def _reduce(
         self,
         documents: List[Union[str, Doc]],
+        as_generator: bool,
         fit_model: bool,
         fit_after_n_batches: int,
     ):
@@ -178,11 +179,11 @@ def _reduce_batch(
         fit_after_n_batches: int,
     ) -> Union[List, Generator]:
         if as_generator:
-            return self._reduce(documents, fit_model, fit_after_n_batches)
+            return self._reduce(documents, as_generator, fit_model, fit_after_n_batches)
         else:
             embeddings = []
             for embedding_batch in self._reduce(
-                documents, fit_model, fit_after_n_batches
+                documents, as_generator, fit_model, fit_after_n_batches
             ):
                 embeddings.extend(embedding_batch)
             return embeddings
```

src/embedders/classification/reduce.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
```diff
@@ -21,6 +21,7 @@ def _transform(
     def _reduce(
         self,
         documents: List[Union[str, Doc]],
+        as_generator: bool,
         fit_model: bool,
         fit_after_n_batches: int,
     ) -> Generator[List[List[Union[float, int]]], None, None]:
@@ -56,8 +57,16 @@ def _reduce(
                 if batch_idx > fit_after_n_batches:
                     yield self._transform(batch)
         else:
-            embeddings = self.embedder.transform(documents)
-            yield self._transform(embeddings)
+            if as_generator:
+                embeddings = [
+                    emb
+                    for batch in self.embedder.transform(documents, as_generator)
+                    for emb in batch
+                ]
+                yield from util.batch(self._transform(embeddings), self.batch_size)
+            else:
+                embeddings = self.embedder.transform(documents)
+                yield self._transform(embeddings)

     @staticmethod
     def load(embedder: dict) -> "PCASentenceReducer":
```

src/embedders/extraction/reduce.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
```diff
@@ -1,3 +1,4 @@
+from spacy.tokens.doc import Doc
 from typing import List, Generator, Union
 import numpy as np
 from src.embedders import PCAReducer, util
@@ -24,7 +25,11 @@ def _transform(
         return batch_unsqueezed

     def _reduce(
-        self, documents, fit_model, fit_after_n_batches
+        self,
+        documents: List[Union[str, Doc]],
+        as_generator: bool,
+        fit_model: bool,
+        fit_after_n_batches: int,
     ) -> Generator[List[List[List[Union[float, int]]]], None, None]:
         if fit_model:
             embeddings_training = []
@@ -60,5 +65,13 @@ def _reduce(
                 if batch_idx > fit_after_n_batches:
                     yield self._transform(batch)
         else:
-            embeddings = self.embedder.transform(documents)
-            yield self._transform(embeddings)
+            if as_generator:
+                embeddings = [
+                    emb
+                    for batch in self.embedder.transform(documents, as_generator)
+                    for emb in batch
+                ]
+                yield from util.batch(self._transform(embeddings), self.batch_size)
+            else:
+                embeddings = self.embedder.transform(documents)
+                yield self._transform(embeddings)
```

0 commit comments

Comments (0)