diff --git a/gensim/models/atmodel.py b/gensim/models/atmodel.py index 838c7634e3..b0720c2f19 100755 --- a/gensim/models/atmodel.py +++ b/gensim/models/atmodel.py @@ -213,7 +213,8 @@ def __init__(self, corpus=None, num_topics=100, id2word=None, author2doc=None, d * 'symmetric': (default) Uses a fixed symmetric prior of `1.0 / num_topics`, * 'auto': Learns an asymmetric prior from the corpus. update_every : int, optional - Make updates in topic probability for latest mini-batch. + Number of chunks to be iterated through before each M step of EM. + Set to 0 for batch learning, >= 1 for online iterative learning. eval_every : int, optional Calculate and estimate log perplexity for latest mini-batch. gamma_threshold : float, optional @@ -803,7 +804,7 @@ def update(self, corpus=None, author2doc=None, doc2author=None, chunksize=None, self.state.numdocs += lencorpus if update_every: - updatetype = "online" + updatetype = "online (single-pass)" if self.passes == 1 else "online (multi-pass)" updateafter = min(lencorpus, update_every * self.numworkers * chunksize) else: updatetype = "batch" diff --git a/gensim/models/ldamodel.py b/gensim/models/ldamodel.py index 6691ddcc31..ac0d0691a1 100755 --- a/gensim/models/ldamodel.py +++ b/gensim/models/ldamodel.py @@ -374,7 +374,7 @@ def __init__(self, corpus=None, num_topics=100, id2word=None, passes : int, optional Number of passes through the corpus during training. update_every : int, optional - Number of documents to be iterated through for each update. + Number of chunks to be iterated through before each M step of EM. Set to 0 for batch learning, > 1 for online iterative learning. 
alpha : {float, numpy.ndarray of float, list of float, str}, optional A-priori belief on document-topic distribution, this can be: @@ -932,11 +932,7 @@ def update(self, corpus, chunksize=None, decay=None, offset=None, self.state.numdocs += lencorpus if update_every: - updatetype = "online" - if passes == 1: - updatetype += " (single-pass)" - else: - updatetype += " (multi-pass)" + updatetype = "online (single-pass)" if self.passes == 1 else "online (multi-pass)" updateafter = min(lencorpus, update_every * self.numworkers * chunksize) else: updatetype = "batch" @@ -1053,7 +1049,7 @@ def do_mstep(self, rho, other, extra_pass=False): ---------- rho : float Learning rate. - other : :class:`~gensim.models.ldamodel.LdaModel` + other : :class:`~gensim.models.ldamodel.LdaState` The model whose sufficient statistics will be used to update the topics. extra_pass : bool, optional Whether this step required an additional pass over the corpus. diff --git a/gensim/models/ldamulticore.py b/gensim/models/ldamulticore.py index fdb5ce70a9..4171bfb35b 100644 --- a/gensim/models/ldamulticore.py +++ b/gensim/models/ldamulticore.py @@ -107,7 +107,7 @@ class LdaMulticore(LdaModel): """ def __init__(self, corpus=None, num_topics=100, id2word=None, workers=None, - chunksize=2000, passes=1, batch=False, alpha='symmetric', + chunksize=2000, passes=1, update_every=1, alpha='symmetric', eta=None, decay=0.5, offset=1.0, eval_every=10, iterations=50, gamma_threshold=0.001, random_state=None, minimum_probability=0.01, minimum_phi_value=0.01, per_word_topics=False, dtype=np.float32): @@ -133,6 +133,9 @@ def __init__(self, corpus=None, num_topics=100, id2word=None, workers=None, Number of documents to be used in each training chunk. passes : int, optional Number of passes through the corpus during training. + update_every : int, optional + Number of chunks to be iterated through before each M step of EM. + Set to 0 for batch learning, >= 1 for online iterative learning. 
alpha : {float, numpy.ndarray of float, list of float, str}, optional A-priori belief on document-topic distribution, this can be: * scalar for a symmetric prior over document-topic distribution, @@ -178,24 +181,20 @@ def __init__(self, corpus=None, num_topics=100, id2word=None, workers=None, """ self.workers = max(1, cpu_count() - 1) if workers is None else workers - self.batch = batch if isinstance(alpha, str) and alpha == 'auto': raise NotImplementedError("auto-tuning alpha not implemented in LdaMulticore; use plain LdaModel.") super(LdaMulticore, self).__init__( - corpus=corpus, num_topics=num_topics, - id2word=id2word, chunksize=chunksize, passes=passes, alpha=alpha, eta=eta, + corpus=corpus, num_topics=num_topics, id2word=id2word, distributed=False, # not distributed across machines + chunksize=chunksize, passes=passes, update_every=update_every, alpha=alpha, eta=eta, decay=decay, offset=offset, eval_every=eval_every, iterations=iterations, - gamma_threshold=gamma_threshold, random_state=random_state, minimum_probability=minimum_probability, + gamma_threshold=gamma_threshold, minimum_probability=minimum_probability, random_state=random_state, minimum_phi_value=minimum_phi_value, per_word_topics=per_word_topics, dtype=dtype, ) def update(self, corpus, chunks_as_numpy=False): - """Train the model with new documents, by EM-iterating over `corpus` until the topics converge - (or until the maximum number of allowed iterations is reached). - - Train the model with new documents, by EM-iterating over the corpus until the topics converge, or until + """Train the model with new documents, by EM-iterating over the corpus until the topics converge, or until the maximum number of allowed iterations is reached. `corpus` must be an iterable. The E step is distributed into the several processes. 
@@ -231,14 +230,16 @@ def update(self, corpus, chunks_as_numpy=False): self.state.numdocs += lencorpus - if self.batch: + # Same as in LdaModel but self.workers (processes) is used instead of self.numworkers (machines) + if self.update_every: + updatetype = "online (single-pass)" if self.passes == 1 else "online (multi-pass)" + updateafter = min(lencorpus, self.update_every * self.workers * self.chunksize) + else: updatetype = "batch" updateafter = lencorpus - else: - updatetype = "online" - updateafter = self.chunksize * self.workers + eval_every = self.eval_every or 0 - evalafter = min(lencorpus, eval_every * updateafter) + evalafter = min(lencorpus, eval_every * self.workers * self.chunksize) updates_per_pass = max(1, lencorpus / updateafter) logger.info(