Skip to content

Commit 1e1898b

Browse files
authored
Add LDA.py
1 parent 20cf886 commit 1e1898b

File tree

1 file changed

+187
-0
lines changed

1 file changed

+187
-0
lines changed

LDA/LDA.py

Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
#coding=utf-8
2+
#Author:Harold
3+
#Date:2021-1-27
4+
5+
6+
'''
7+
数据集:bbc_text
8+
数据集数量:2225
9+
-----------------------------
10+
运行结果:
11+
话题数:5
12+
原始话题:'tech', 'business', 'sport', 'entertainment', 'politics'
13+
生成话题:
14+
1:'said game england would time first back play last good'
15+
2:'said year would economy growth also economic bank government could'
16+
3:'said year games sales company also market last firm 2004'
17+
4:'film said music best also people year show number digital'
18+
5:'said would people government labour election party blair could also'
19+
运行时长:7620.51s
20+
'''
21+
22+
import numpy as np
23+
import pandas as pd
24+
import string
25+
from nltk.corpus import stopwords
26+
import time
27+
28+
29+
#定义加载数据的函数
30+
def load_data(file, K):
31+
'''
32+
INPUT:
33+
file - (str) 数据文件的路径
34+
K - (int) 设定的话题数
35+
36+
OUTPUT:
37+
org_topics - (list) 原始话题标签列表
38+
text - (list) 文本列表
39+
words - (list) 单词列表
40+
alpha - (list) 话题概率分布,模型超参数
41+
beta - (list) 单词概率分布,模型超参数
42+
43+
'''
44+
df = pd.read_csv(file) #读取文件
45+
org_topics = df['category'].unique().tolist() #保存文本原始的话题标签
46+
M = df.shape[0] #文本数
47+
alpha = np.zeros(K) #alpha是LDA模型的一个超参数,是对话题概率的预估计,这里取文本数据中各话题的比例作为alpha值,实际可以通过模型训练得到
48+
beta = np.zeros(1000) #beta是LDA模型的另一个超参数,是词汇表中单词的概率分布,这里取各单词在所有文本中的比例作为beta值,实际也可以通过模型训练得到
49+
#计算各话题的比例作为alpha值
50+
for k, topic in enumerate(org_topics):
51+
alpha[k] = df[df['category'] == topic].shape[0] / M
52+
df.drop('category', axis=1, inplace=True)
53+
n = df.shape[0] #n为文本数量
54+
text = []
55+
words = []
56+
for i in df['text'].values:
57+
t = i.translate(str.maketrans('', '', string.punctuation)) #去除文本中的标点符号
58+
t = [j for j in t.split() if j not in stopwords.words('english')] #去除文本中的停止词
59+
t = [j for j in t if len(j) > 3] #长度小于等于3的单词大多是无意义的,直接去除
60+
text.append(t) #将处理后的文本保存到文本列表中
61+
words.extend(set(t)) #将文本中所包含的单词保存到单词列表中
62+
words = list(set(words)) #去除单词列表中的重复单词
63+
words_cnt = np.zeros(len(words)) #用来保存单词的出现频次
64+
#循环计算words列表中各单词出现的词频
65+
for i in range(len(text)):
66+
t = text[i] #取出第i条文本
67+
for w in t:
68+
ind = words.index(w) #取出第i条文本中的第t个单词在单词列表中的索引
69+
words_cnt[ind] += 1 #对应位置的单词出现频次加一
70+
sort_inds = np.argsort(words_cnt)[::-1] #对单词出现频次降序排列后取出其索引值
71+
words = [words[ind] for ind in sort_inds[:1000]] #将出现频次前1000的单词保存到words列表
72+
#去除文本text中不在词汇表words中的单词
73+
for i in range(len(text)):
74+
t = []
75+
for w in text[i]:
76+
if w in words:
77+
ind = words.index(w)
78+
t.append(w)
79+
beta[ind] += 1 #统计各单词在文本中的出现频次
80+
text[i] = t
81+
beta /= np.sum(beta) #除以文本的总单词数得到各单词所占比例,作为beta值
82+
return org_topics, text, words, alpha, beta
83+
84+
85+
#定义潜在狄利克雷分配函数,采用收缩的吉布斯抽样算法估计模型的参数theta和phi
86+
def do_lda(text, words, alpha, beta, K, iters):
87+
'''
88+
INPUT:
89+
text - (list) 文本列表
90+
words - (list) 单词列表
91+
alpha - (list) 话题概率分布,模型超参数
92+
beta - (list) 单词概率分布,模型超参数
93+
K - (int) 设定的话题数
94+
iters - (int) 设定的迭代次数
95+
96+
OUTPUT:
97+
theta - (array) 话题的条件概率分布p(zk|dj),这里写成p(zk|dj)是为了和PLSA模型那一章的符号统一一下,方便对照着看
98+
phi - (array) 单词的条件概率分布p(wi|zk)
99+
100+
'''
101+
M = len(text) #文本数
102+
V = len(words) #单词数
103+
N_MK = np.zeros((M, K)) #文本-话题计数矩阵
104+
N_KV = np.zeros((K, V)) #话题-单词计数矩阵
105+
N_M = np.zeros(M) #文本计数向量
106+
N_K = np.zeros(K) #话题计数向量
107+
Z_MN = [] #用来保存每条文本的每个单词所在位置处抽样得到的话题
108+
#算法20.2的步骤(2),对每个文本的所有单词抽样产生话题,并进行计数
109+
for m in range(M):
110+
zm = []
111+
t = text[m]
112+
for n, w in enumerate(t):
113+
v = words.index(w)
114+
z = np.random.randint(K)
115+
zm.append(z)
116+
N_MK[m, z] += 1
117+
N_M[m] += 1
118+
N_KV[z, v] += 1
119+
N_K[z] += 1
120+
Z_MN.append(zm)
121+
#算法20.2的步骤(3),多次迭代进行吉布斯抽样
122+
for i in range(iters):
123+
print('{}/{}'.format(i+1, iters))
124+
for m in range(M):
125+
t = text[m]
126+
for n, w in enumerate(t):
127+
v = words.index(w)
128+
z = Z_MN[m][n]
129+
N_MK[m, z] -= 1
130+
N_M[m] -= 1
131+
N_KV[z][v] -= 1
132+
N_K[z] -= 1
133+
p = [] #用来保存对K个话题的条件分布p(zi|z_i,w,alpha,beta)的计算结果
134+
sums_k = 0
135+
for k in range(K):
136+
p_zk = (N_KV[k][v] + beta[v]) * (N_MK[m][k] + alpha[k]) #话题zi=k的条件分布p(zi|z_i,w,alpha,beta)的分子部分
137+
sums_v = 0
138+
sums_k += N_MK[m][k] + alpha[k] #累计(nmk + alpha_k)在K个话题上的和
139+
for t in range(V):
140+
sums_v += N_KV[k][t] + beta[t] #累计(nkv + beta_v)在V个单词上的和
141+
p_zk /= sums_v
142+
p.append(p_zk)
143+
p = p / sums_k
144+
p = p / np.sum(p) #对条件分布p(zi|z_i,w,alpha,beta)进行归一化,保证概率的总和为1
145+
new_z = np.random.choice(a=K, p=p) #根据以上计算得到的概率进行抽样,得到新的话题
146+
Z_MN[m][n] = new_z #更新当前位置处的话题为上面抽样得到的新话题
147+
#更新计数
148+
N_MK[m, new_z] += 1
149+
N_M[m] += 1
150+
N_KV[new_z, v] += 1
151+
N_K[new_z] += 1
152+
#算法20.2的步骤(4),利用得到的样本计数,估计模型的参数theta和phi
153+
theta = np.zeros((M, K))
154+
phi = np.zeros((K, V))
155+
for m in range(M):
156+
sums_k = 0
157+
for k in range(K):
158+
theta[m, k] = N_MK[m][k] + alpha[k] #参数theta的分子部分
159+
sums_k += theta[m, k] #累计(nmk + alpha_k)在K个话题上的和,参数theta的分母部分
160+
theta[m] /= sums_k #计算参数theta
161+
for k in range(K):
162+
sums_v = 0
163+
for v in range(V):
164+
phi[k, v] = N_KV[k][v] + beta[v] #参数phi的分子部分
165+
sums_v += phi[k][v] #累计(nkv + beta_v)在V个单词上的和,参数phi的分母部分
166+
phi[k] /= sums_v #计算参数phi
167+
return theta, phi
168+
169+
170+
if __name__ == "__main__":
171+
K = 5 #设定话题数为5
172+
org_topics, text, words, alpha, beta = load_data('bbc_text.csv', K) #加载数据
173+
print('Original Topics:')
174+
print(org_topics) #打印原始的话题标签列表
175+
start = time.time() #保存开始时间
176+
iters = 10 #为了避免运行时间过长,这里只迭代10次,实际上10次是不够的,要迭代足够的次数保证吉布斯抽样进入燃烧期,这样得到的参数才能尽可能接近样本的实际概率分布
177+
theta, phi = do_lda(text, words, alpha, beta, K, iters) #LDA的吉布斯抽样
178+
#打印出每个话题zk条件下出现概率最大的前10个单词,即P(wi|zk)在话题zk中最大的10个值对应的单词,作为对话题zk的文本描述
179+
for k in range(K):
180+
sort_inds = np.argsort(phi[k])[::-1] #对话题zk条件下的P(wi|zk)的值进行降序排列后取出对应的索引值
181+
topic = [] #定义一个空列表用于保存话题zk概率最大的前10个单词
182+
for i in range(10):
183+
topic.append(words[sort_inds[i]])
184+
topic = ' '.join(topic) #将10个单词以空格分隔,构成对话题zk的文本表述
185+
print('Topic {}: {}'.format(k+1, topic)) #打印话题zk
186+
end = time.time()
187+
print('Time:', end-start)

0 commit comments

Comments
 (0)