
Commit 9ac7d09

Merge branch 'dev' of github.com:fastnlp/fastNLP into dev
2 parents 24c5087 + c18b205

3 files changed: +287 −6 lines

fastNLP/io/file_utils.py

Lines changed: 5 additions & 0 deletions

@@ -103,6 +103,11 @@
     "yelp-review-polarity": "yelp_review_polarity.tar.gz",
     "sst-2": "SST-2.zip",
     "sst": "SST.zip",
+    "mr": "mr.zip",
+    "R8": "R8.zip",
+    "R52": "R52.zip",
+    "20ng": "20ng.zip",
+    "ohsumed": "ohsumed.zip",

     # Classification, Chinese
     "chn-senti-corp": "chn_senti_corp.zip",

fastNLP/io/pipe/__init__.py

Lines changed: 14 additions & 6 deletions

@@ -23,15 +23,15 @@
     "ChnSentiCorpPipe",
     "THUCNewsPipe",
     "WeiboSenti100kPipe",
-    "MRPipe", "R52Pipe", "R8Pipe", "OhsumedPipe", "NG20Loader",
-
+    "MRPipe", "R52Pipe", "R8Pipe", "OhsumedPipe", "NG20Pipe",
+
     "Conll2003NERPipe",
     "OntoNotesNERPipe",
     "MsraNERPipe",
     "WeiboNERPipe",
     "PeopleDailyPipe",
     "Conll2003Pipe",
-
+
     "MatchingBertPipe",
     "RTEBertPipe",
     "SNLIBertPipe",
@@ -53,14 +53,20 @@
     "RenamePipe",
     "GranularizePipe",
     "MachingTruncatePipe",
-
+
     "CoReferencePipe",

-    "CMRC2018BertPipe"
+    "CMRC2018BertPipe",
+
+    "R52PmiGraphPipe",
+    "R8PmiGraphPipe",
+    "OhsumedPmiGraphPipe",
+    "NG20PmiGraphPipe",
+    "MRPmiGraphPipe"
 ]

 from .classification import CLSBasePipe, YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe, ChnSentiCorpPipe, THUCNewsPipe, \
-    WeiboSenti100kPipe, AGsNewsPipe, DBPediaPipe, MRPipe, R8Pipe, R52Pipe, OhsumedPipe, NG20Loader
+    WeiboSenti100kPipe, AGsNewsPipe, DBPediaPipe, MRPipe, R8Pipe, R52Pipe, OhsumedPipe, NG20Pipe
 from .conll import Conll2003NERPipe, OntoNotesNERPipe, MsraNERPipe, WeiboNERPipe, PeopleDailyPipe
 from .conll import Conll2003Pipe
 from .coreference import CoReferencePipe
@@ -70,3 +76,5 @@
     LCQMCPipe, BQCorpusPipe, LCQMCBertPipe, RenamePipe, GranularizePipe, MachingTruncatePipe
 from .pipe import Pipe
 from .qa import CMRC2018BertPipe
+
+from .construct_graph import MRPmiGraphPipe, R8PmiGraphPipe, R52PmiGraphPipe, NG20PmiGraphPipe, OhsumedPmiGraphPipe
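With these exports in place, the new graph pipes are importable directly from the package:

    from fastNLP.io.pipe import MRPmiGraphPipe, R8PmiGraphPipe, NG20PmiGraphPipe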

fastNLP/io/pipe/construct_graph.py

Lines changed: 268 additions & 0 deletions (new file)

@@ -0,0 +1,268 @@
__all__ = [
    'MRPmiGraphPipe',
    'R8PmiGraphPipe',
    'R52PmiGraphPipe',
    'OhsumedPmiGraphPipe',
    'NG20PmiGraphPipe'
]

try:
    import networkx as nx
    from sklearn.feature_extraction.text import CountVectorizer
    from sklearn.feature_extraction.text import TfidfTransformer
    from sklearn.pipeline import Pipeline
except ImportError:
    pass

from collections import defaultdict
import itertools
import math
from tqdm import tqdm
import numpy as np

from ..data_bundle import DataBundle
from ...core.const import Const
from ..loader.classification import MRLoader, OhsumedLoader, R52Loader, R8Loader, NG20Loader


def _get_windows(content_lst: list, window_size: int):
    r"""
    Slide a fixed-size window over each text to collect word frequencies and
    word co-occurrence frequencies.

    :param content_lst: list of texts, each a string or a list of tokens
    :param window_size: number of tokens per window
    :return: word frequencies, word-pair co-occurrence frequencies, and the
        total number of windows produced
    """
    word_window_freq = defaultdict(int)  # w(i): number of windows containing word i
    word_pair_count = defaultdict(int)  # w(i, j): number of windows containing both i and j
    windows_len = 0
    for words in tqdm(content_lst, desc="Split by window"):
        windows = list()

        if isinstance(words, str):
            words = words.split()
        length = len(words)

        if length <= window_size:
            windows.append(words)
        else:
            for j in range(length - window_size + 1):
                window = words[j: j + window_size]
                windows.append(list(set(window)))

        for window in windows:
            for word in window:
                word_window_freq[word] += 1

            for word_pair in itertools.combinations(window, 2):
                word_pair_count[word_pair] += 1

        windows_len += len(windows)
    return word_window_freq, word_pair_count, windows_len
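# Illustration (not part of this commit): for words = ['a', 'b', 'c'] and
# window_size = 2, _get_windows produces the windows ['a', 'b'] and ['b', 'c'],
# so word_window_freq == {'a': 1, 'b': 2, 'c': 1} and windows_len == 2, while
# word_pair_count holds one count per window pair (pair tuples may come out in
# either order, since each window is deduplicated with set()).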

def _cal_pmi(W_ij, W, word_freq_i, word_freq_j):
    r"""
    :param W_ij: co-occurrence count of words i and j
    :param W: total number of windows
    :param word_freq_i: window frequency of word i
    :param word_freq_j: window frequency of word j
    :return: the PMI value of words i and j
    """
    # PMI(i, j) = log( p(i, j) / (p(i) * p(j)) ), with probabilities
    # estimated from window counts.
    p_i = word_freq_i / W
    p_j = word_freq_j / W
    p_i_j = W_ij / W
    pmi = math.log(p_i_j / (p_i * p_j))

    return pmi

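# Illustration (not part of this commit): with W = 100 windows, word i in 10
# of them, word j in 20, and both together in 5, we get p_i = 0.1, p_j = 0.2
# and p_i_j = 0.05, so _cal_pmi(5, 100, 10, 20) == math.log(2.5) ≈ 0.916:
# the pair co-occurs 2.5x more often than independence would predict.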

def _count_pmi(windows_len, word_pair_count, word_window_freq, threshold):
    r"""
    :param windows_len: total number of windows
    :param word_pair_count: dict of word-pair co-occurrence counts
    :param word_window_freq: dict of per-word window frequencies
    :param threshold: PMI threshold; pairs at or below it are dropped
    :return: list of PMI entries, each of the form [word1, word2, pmi]
    """
    word_pmi_lst = list()
    for word_pair, W_i_j in tqdm(word_pair_count.items(), desc="Calculate pmi between words"):
        word_freq_1 = word_window_freq[word_pair[0]]
        word_freq_2 = word_window_freq[word_pair[1]]

        pmi = _cal_pmi(W_i_j, windows_len, word_freq_1, word_freq_2)
        if pmi <= threshold:
            continue
        word_pmi_lst.append([word_pair[0], word_pair[1], pmi])
    return word_pmi_lst

class GraphBuilderBase:
    def __init__(self, graph_type='pmi', window_size=10, threshold=0.):
        self.graph = nx.Graph()
        self.word2id = dict()
        self.graph_type = graph_type
        self.window_size = window_size
        self.doc_node_num = 0
        self.tr_doc_index = None
        self.te_doc_index = None
        self.dev_doc_index = None
        self.doc = None
        self.threshold = threshold

    def _get_doc_edge(self, data_bundle: DataBundle):
        r'''
        Process the input DataBundle and add tfidf-weighted document-word
        edges. Documents occupy node ids 0 .. doc_node_num - 1; vocabulary
        words occupy node ids doc_node_num .. doc_node_num + |V| - 1.

        :param data_bundle: texts are whitespace-separated tokens, e.g.
            ['This is the first document.'] for English or ['他 喜欢 吃 苹果']
            for Chinese
        :return: the document-word graph as a scipy sparse matrix with
            tfidf-weighted edges
        '''
        tr_doc = list(data_bundle.get_dataset("train").get_field(Const.RAW_WORD))
        val_doc = list(data_bundle.get_dataset("dev").get_field(Const.RAW_WORD))
        te_doc = list(data_bundle.get_dataset("test").get_field(Const.RAW_WORD))
        doc = tr_doc + val_doc + te_doc
        self.doc = doc
        self.tr_doc_index = [ind for ind in range(len(tr_doc))]
        self.dev_doc_index = [ind + len(tr_doc) for ind in range(len(val_doc))]
        self.te_doc_index = [ind + len(tr_doc) + len(val_doc) for ind in range(len(te_doc))]
        text_tfidf = Pipeline([('count', CountVectorizer(token_pattern=r'\S+', min_df=1, max_df=1.0)),
                               ('tfidf', TfidfTransformer(norm=None, use_idf=True, smooth_idf=False, sublinear_tf=False))])

        tfidf_vec = text_tfidf.fit_transform(doc)
        self.doc_node_num = tfidf_vec.shape[0]
        vocab_lst = text_tfidf['count'].get_feature_names()
        for ind, word in enumerate(vocab_lst):
            self.word2id[word] = ind
        for ind, row in enumerate(tfidf_vec):
            for col_index, value in zip(row.indices, row.data):
                self.graph.add_edge(ind, self.doc_node_num + col_index, weight=value)
        return nx.to_scipy_sparse_matrix(self.graph)

    def _get_word_edge(self):
        # Add PMI-weighted word-word edges computed over sliding windows.
        word_window_freq, word_pair_count, windows_len = _get_windows(self.doc, self.window_size)
        pmi_edge_lst = _count_pmi(windows_len, word_pair_count, word_window_freq, self.threshold)
        for edge_item in pmi_edge_lst:
            word_indx1 = self.doc_node_num + self.word2id[edge_item[0]]
            word_indx2 = self.doc_node_num + self.word2id[edge_item[1]]
            if word_indx1 == word_indx2:
                continue
            self.graph.add_edge(word_indx1, word_indx2, weight=edge_item[2])

    def build_graph(self, data_bundle: DataBundle):
        r"""
        Process the input DataBundle and return the adjacency matrix as a
        scipy sparse matrix.

        :param ~fastNLP.DataBundle data_bundle: the DataBundle to process
        :return:
        """
        raise NotImplementedError

    def build_graph_from_file(self, path: str):
        r"""
        Build the scipy sparse matrix from a file path. The supported path
        formats are the same as for :meth:`fastNLP.io.Loader.load()`.

        :param path:
        :return: scipy_sparse_matrix
        """
        raise NotImplementedError


class MRPmiGraphPipe(GraphBuilderBase):

    def __init__(self, graph_type='pmi', window_size=10, threshold=0.):
        super().__init__(graph_type=graph_type, window_size=window_size, threshold=threshold)

    def build_graph(self, data_bundle: DataBundle):
        r'''
        :param ~fastNLP.DataBundle data_bundle: the DataBundle to process.
        :return: a csr sparse-matrix graph, plus the indices of the train, dev
            and test documents within the graph.
        '''
        self._get_doc_edge(data_bundle)
        self._get_word_edge()
        return nx.to_scipy_sparse_matrix(self.graph,
                                         nodelist=list(range(self.graph.number_of_nodes())),
                                         weight='weight', dtype=np.float32, format='csr'), (self.tr_doc_index, self.dev_doc_index, self.te_doc_index)

    def build_graph_from_file(self, path: str):
        data_bundle = MRLoader().load(path)
        return self.build_graph(data_bundle)


class R8PmiGraphPipe(GraphBuilderBase):

    def __init__(self, graph_type='pmi', window_size=10, threshold=0.):
        super().__init__(graph_type=graph_type, window_size=window_size, threshold=threshold)

    def build_graph(self, data_bundle: DataBundle):
        r'''
        :param ~fastNLP.DataBundle data_bundle: the DataBundle to process.
        :return: a csr sparse-matrix graph, plus the indices of the train, dev
            and test documents within the graph.
        '''
        self._get_doc_edge(data_bundle)
        self._get_word_edge()
        return nx.to_scipy_sparse_matrix(self.graph,
                                         nodelist=list(range(self.graph.number_of_nodes())),
                                         weight='weight', dtype=np.float32, format='csr'), (self.tr_doc_index, self.dev_doc_index, self.te_doc_index)

    def build_graph_from_file(self, path: str):
        data_bundle = R8Loader().load(path)
        return self.build_graph(data_bundle)


class R52PmiGraphPipe(GraphBuilderBase):

    def __init__(self, graph_type='pmi', window_size=10, threshold=0.):
        super().__init__(graph_type=graph_type, window_size=window_size, threshold=threshold)

    def build_graph(self, data_bundle: DataBundle):
        r'''
        :param ~fastNLP.DataBundle data_bundle: the DataBundle to process.
        :return: a csr sparse-matrix graph, plus the indices of the train, dev
            and test documents within the graph.
        '''
        self._get_doc_edge(data_bundle)
        self._get_word_edge()
        return nx.to_scipy_sparse_matrix(self.graph,
                                         nodelist=list(range(self.graph.number_of_nodes())),
                                         weight='weight', dtype=np.float32, format='csr'), (self.tr_doc_index, self.dev_doc_index, self.te_doc_index)

    def build_graph_from_file(self, path: str):
        data_bundle = R52Loader().load(path)
        return self.build_graph(data_bundle)


class OhsumedPmiGraphPipe(GraphBuilderBase):

    def __init__(self, graph_type='pmi', window_size=10, threshold=0.):
        super().__init__(graph_type=graph_type, window_size=window_size, threshold=threshold)

    def build_graph(self, data_bundle: DataBundle):
        r'''
        :param ~fastNLP.DataBundle data_bundle: the DataBundle to process.
        :return: a csr sparse-matrix graph, plus the indices of the train, dev
            and test documents within the graph.
        '''
        self._get_doc_edge(data_bundle)
        self._get_word_edge()
        return nx.to_scipy_sparse_matrix(self.graph,
                                         nodelist=list(range(self.graph.number_of_nodes())),
                                         weight='weight', dtype=np.float32, format='csr'), (self.tr_doc_index, self.dev_doc_index, self.te_doc_index)

    def build_graph_from_file(self, path: str):
        data_bundle = OhsumedLoader().load(path)
        return self.build_graph(data_bundle)


class NG20PmiGraphPipe(GraphBuilderBase):

    def __init__(self, graph_type='pmi', window_size=10, threshold=0.):
        super().__init__(graph_type=graph_type, window_size=window_size, threshold=threshold)

    def build_graph(self, data_bundle: DataBundle):
        r'''
        :param ~fastNLP.DataBundle data_bundle: the DataBundle to process.
        :return: a csr sparse-matrix graph, plus the indices of the train, dev
            and test documents within the graph.
        '''
        self._get_doc_edge(data_bundle)
        self._get_word_edge()
        return nx.to_scipy_sparse_matrix(self.graph,
                                         nodelist=list(range(self.graph.number_of_nodes())),
                                         weight='weight', dtype=np.float32, format='csr'), (
            self.tr_doc_index, self.dev_doc_index, self.te_doc_index)

    def build_graph_from_file(self, path: str):
        r'''
        :param path: path to the dataset.
        :return: a csr sparse-matrix graph, plus the indices of the train, dev
            and test documents within the graph.
        '''
        data_bundle = NG20Loader().load(path)
        return self.build_graph(data_bundle)
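
Taken together, a typical use of the new pipes is sketched below (an illustrative sketch, not from this commit: the dataset path is a placeholder, and the calls follow the signatures defined above):

    from fastNLP.io.pipe import MRPmiGraphPipe

    pipe = MRPmiGraphPipe(window_size=10, threshold=0.)
    # rows/columns of adj: document nodes first, then vocabulary nodes
    adj, (train_idx, dev_idx, test_idx) = pipe.build_graph_from_file('path/to/mr')  # placeholder path
    print(adj.shape, len(train_idx), len(dev_idx), len(test_idx))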
