3 files changed: +30 -7 lines changed
Original file line number | Diff line number | Diff line change
@@ -165,6 +165,7 @@ def load_pca_weights(self, file_name: str):
165165 def _reduce (
166166 self ,
167167 documents : List [Union [str , Doc ]],
168+ as_generator : bool ,
168169 fit_model : bool ,
169170 fit_after_n_batches : int ,
170171 ):
@@ -178,11 +179,11 @@ def _reduce_batch(
178179 fit_after_n_batches : int ,
179180 ) -> Union [List , Generator ]:
180181 if as_generator :
181- return self ._reduce (documents , fit_model , fit_after_n_batches )
182+ return self ._reduce (documents , as_generator , fit_model , fit_after_n_batches )
182183 else :
183184 embeddings = []
184185 for embedding_batch in self ._reduce (
185- documents , fit_model , fit_after_n_batches
186+ documents , as_generator , fit_model , fit_after_n_batches
186187 ):
187188 embeddings .extend (embedding_batch )
188189 return embeddings
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,7 @@ def _transform(
2121 def _reduce (
2222 self ,
2323 documents : List [Union [str , Doc ]],
24+ as_generator : bool ,
2425 fit_model : bool ,
2526 fit_after_n_batches : int ,
2627 ) -> Generator [List [List [Union [float , int ]]], None , None ]:
@@ -56,8 +57,16 @@ def _reduce(
5657 if batch_idx > fit_after_n_batches :
5758 yield self ._transform (batch )
5859 else :
59- embeddings = self .embedder .transform (documents )
60- yield self ._transform (embeddings )
60+ if as_generator :
61+ embeddings = [
62+ emb
63+ for batch in self .embedder .transform (documents , as_generator )
64+ for emb in batch
65+ ]
66+ yield from util .batch (self ._transform (embeddings ), self .batch_size )
67+ else :
68+ embeddings = self .embedder .transform (documents )
69+ yield self ._transform (embeddings )
6170
6271 @staticmethod
6372 def load (embedder : dict ) -> "PCASentenceReducer" :
Original file line number | Diff line number | Diff line change
1+ from spacy .tokens .doc import Doc
12from typing import List , Generator , Union
23import numpy as np
34from src .embedders import PCAReducer , util
@@ -24,7 +25,11 @@ def _transform(
2425 return batch_unsqueezed
2526
2627 def _reduce (
27- self , documents , fit_model , fit_after_n_batches
28+ self ,
29+ documents : List [Union [str , Doc ]],
30+ as_generator : bool ,
31+ fit_model : bool ,
32+ fit_after_n_batches : int ,
2833 ) -> Generator [List [List [List [Union [float , int ]]]], None , None ]:
2934 if fit_model :
3035 embeddings_training = []
@@ -60,5 +65,13 @@ def _reduce(
6065 if batch_idx > fit_after_n_batches :
6166 yield self ._transform (batch )
6267 else :
63- embeddings = self .embedder .transform (documents )
64- yield self ._transform (embeddings )
68+ if as_generator :
69+ embeddings = [
70+ emb
71+ for batch in self .embedder .transform (documents , as_generator )
72+ for emb in batch
73+ ]
74+ yield from util .batch (self ._transform (embeddings ), self .batch_size )
75+ else :
76+ embeddings = self .embedder .transform (documents )
77+ yield self ._transform (embeddings )
You can’t perform that action at this time.
0 commit comments