
Commit 7c0b61b

Use update_every in ldamulticore.py
batch=True  --> update_every=0
batch=False --> update_every=1
New: update_every >= 2 (as used in ldamodel.py)
Parent: adb74a3
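In practice the mapping is a one-keyword change at the call site. A minimal migration sketch (hypothetical `corpus` and `dictionary`, assumed to be a prebuilt bag-of-words corpus and gensim Dictionary):

from gensim.models import LdaMulticore

# corpus, dictionary: assumed prebuilt bag-of-words corpus and Dictionary (not shown)
# Old API (before this commit):  LdaMulticore(corpus, id2word=dictionary, batch=True)
# New API, equivalent behaviour via update_every:
lda_batch  = LdaMulticore(corpus, id2word=dictionary, num_topics=10, update_every=0)  # was batch=True
lda_online = LdaMulticore(corpus, id2word=dictionary, num_topics=10, update_every=1)  # was batch=False (default)
lda_lazy   = LdaMulticore(corpus, id2word=dictionary, num_topics=10, update_every=2)  # new: M step every 2nd chunk-group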

3 files changed, +23 −17 lines changed


gensim/models/atmodel.py

Lines changed: 2 additions & 1 deletion
@@ -213,7 +213,8 @@ def __init__(self, corpus=None, num_topics=100, id2word=None, author2doc=None, d
             * 'symmetric': (default) Uses a fixed symmetric prior of `1.0 / num_topics`,
             * 'auto': Learns an asymmetric prior from the corpus.
         update_every : int, optional
-            Make updates in topic probability for latest mini-batch.
+            Number of chunks to be iterated through before each M step of EM.
+            Set to 0 for batch learning, > 1 for online iterative learning.
         eval_every : int, optional
             Calculate and estimate log perplexity for latest mini-batch.
         gamma_threshold : float, optional
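To make the reworded semantics concrete, a toy, self-contained sketch of the schedule the new docstring describes (stand-in prints rather than gensim's real E/M-step code; `train` and its arguments are hypothetical):

def train(chunks, update_every):
    pending = []                                      # E-step output waiting for an M step
    for i, chunk in enumerate(chunks, start=1):
        pending.append(chunk)                         # stands in for the per-chunk E step
        if update_every and i % update_every == 0:
            print("M step over chunks", pending)      # online: every `update_every` chunks
            pending = []
    if pending:
        print("final M step over chunks", pending)    # batch (update_every=0): single M step

train(list(range(1, 7)), update_every=0)   # batch learning: one M step at the end
train(list(range(1, 7)), update_every=2)   # online: M step after every 2 chunks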

gensim/models/ldamodel.py

Lines changed: 2 additions & 2 deletions
@@ -374,7 +374,7 @@ def __init__(self, corpus=None, num_topics=100, id2word=None,
         passes : int, optional
             Number of passes through the corpus during training.
         update_every : int, optional
-            Number of documents to be iterated through for each update.
+            Number of chunks to be iterated through before each M step of EM.
             Set to 0 for batch learning, > 1 for online iterative learning.
         alpha : {float, numpy.ndarray of float, list of float, str}, optional
             A-priori belief on document-topic distribution, this can be:
@@ -1053,7 +1053,7 @@ def do_mstep(self, rho, other, extra_pass=False):
         ----------
         rho : float
             Learning rate.
-        other : :class:`~gensim.models.ldamodel.LdaModel`
+        other : :class:`~gensim.models.ldamodel.LdaState`
            The model whose sufficient statistics will be used to update the topics.
         extra_pass : bool, optional
             Whether this step required an additional pass over the corpus.
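`rho` is the online learning rate. For reference, the standard Hoffman et al. online-LDA schedule, which the `offset` and `decay` constructor parameters control (a sketch; gensim's exact update-counter bookkeeping may differ):

# rho_t = (offset + t) ** (-decay): starts near 1, decays toward 0
def rho(t, offset=1.0, decay=0.5):
    return (offset + t) ** -decay

print([round(rho(t), 3) for t in range(5)])  # [1.0, 0.707, 0.577, 0.5, 0.447]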

gensim/models/ldamulticore.py

Lines changed: 19 additions & 14 deletions
@@ -107,7 +107,7 @@ class LdaMulticore(LdaModel):

     """
     def __init__(self, corpus=None, num_topics=100, id2word=None, workers=None,
-                 chunksize=2000, passes=1, batch=False, alpha='symmetric',
+                 chunksize=2000, passes=1, update_every=1, alpha='symmetric',
                  eta=None, decay=0.5, offset=1.0, eval_every=10, iterations=50,
                  gamma_threshold=0.001, random_state=None, minimum_probability=0.01,
                  minimum_phi_value=0.01, per_word_topics=False, dtype=np.float32):
@@ -133,6 +133,9 @@ def __init__(self, corpus=None, num_topics=100, id2word=None, workers=None,
             Number of documents to be used in each training chunk.
         passes : int, optional
             Number of passes through the corpus during training.
+        update_every : int, optional
+            Number of chunks to be iterated through before each M step of EM.
+            Set to 0 for batch learning, > 1 for online iterative learning.
         alpha : {float, numpy.ndarray of float, list of float, str}, optional
             A-priori belief on document-topic distribution, this can be:
             * scalar for a symmetric prior over document-topic distribution,
@@ -178,24 +181,20 @@ def __init__(self, corpus=None, num_topics=100, id2word=None, workers=None,

         """
         self.workers = max(1, cpu_count() - 1) if workers is None else workers
-        self.batch = batch

         if isinstance(alpha, str) and alpha == 'auto':
             raise NotImplementedError("auto-tuning alpha not implemented in LdaMulticore; use plain LdaModel.")

         super(LdaMulticore, self).__init__(
-            corpus=corpus, num_topics=num_topics,
-            id2word=id2word, chunksize=chunksize, passes=passes, alpha=alpha, eta=eta,
+            corpus=corpus, num_topics=num_topics, id2word=id2word, distributed=False,  # not distributed across machines
+            chunksize=chunksize, passes=passes, update_every=update_every, alpha=alpha, eta=eta,
             decay=decay, offset=offset, eval_every=eval_every, iterations=iterations,
-            gamma_threshold=gamma_threshold, random_state=random_state, minimum_probability=minimum_probability,
+            gamma_threshold=gamma_threshold, minimum_probability=minimum_probability, random_state=random_state,
             minimum_phi_value=minimum_phi_value, per_word_topics=per_word_topics, dtype=dtype,
         )

     def update(self, corpus, chunks_as_numpy=False):
-        """Train the model with new documents, by EM-iterating over `corpus` until the topics converge
-        (or until the maximum number of allowed iterations is reached).
-
-        Train the model with new documents, by EM-iterating over the corpus until the topics converge, or until
+        """Train the model with new documents, by EM-iterating over the corpus until the topics converge, or until
         the maximum number of allowed iterations is reached. `corpus` must be an iterable. The E step is distributed
         into the several processes.

@@ -231,14 +230,20 @@ def update(self, corpus, chunks_as_numpy=False):

         self.state.numdocs += lencorpus

-        if self.batch:
+        # Same as in LdaModel but self.workers (processes) is used instead of self.numworkers (machines)
+        if self.update_every:
+            updatetype = "online"
+            if self.passes == 1:
+                updatetype += " (single-pass)"
+            else:
+                updatetype += " (multi-pass)"
+            updateafter = min(lencorpus, self.update_every * self.workers * self.chunksize)
+        else:
             updatetype = "batch"
             updateafter = lencorpus
-        else:
-            updatetype = "online"
-            updateafter = self.chunksize * self.workers
+
         eval_every = self.eval_every or 0
-        evalafter = min(lencorpus, eval_every * updateafter)
+        evalafter = min(lencorpus, eval_every * self.workers * self.chunksize)

         updates_per_pass = max(1, lencorpus / updateafter)
         logger.info(
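For intuition, the bookkeeping above works out as follows with made-up sizes (plain arithmetic mirroring the new lines; `//` replaces the original's Python 2 `/`):

lencorpus, chunksize, workers = 100_000, 2_000, 3
update_every, eval_every = 1, 10

if update_every:
    updateafter = min(lencorpus, update_every * workers * chunksize)  # 6_000 docs per M step
else:
    updateafter = lencorpus                                           # batch: one M step per pass

evalafter = min(lencorpus, eval_every * workers * chunksize)          # perplexity every 60_000 docs
updates_per_pass = max(1, lencorpus // updateafter)                   # 16 updates per pass
print(updateafter, evalafter, updates_per_pass)                       # 6000 60000 16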
