1+ # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+ #
3+ # Licensed under the Apache License, Version 2.0 (the "License");
4+ # you may not use this file except in compliance with the License.
5+ # You may obtain a copy of the License at
6+ #
7+ # http://www.apache.org/licenses/LICENSE-2.0
8+ #
9+ # Unless required by applicable law or agreed to in writing, software
10+ # distributed under the License is distributed on an "AS IS" BASIS,
11+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+ # See the License for the specific language governing permissions and
13+ # limitations under the License.
14+
115from __future__ import print_function
216import numpy as np
3- import re ,random
17+ import re , random
418from paddle .io import IterableDataset
519
620
@@ -16,38 +30,50 @@ def __init__(self, file_list, config):
1630 elif re .match ('[\\ S]*article[0-9]*.txt$' , x ) != None :
1731 self .article_file_list .append (x )
1832 self .config = config
19- self .article_content_size = config .get ("hyper_parameters.article_content_size" )
20- self .article_title_size = config .get ("hyper_parameters.article_title_size" )
33+ self .article_content_size = config .get (
34+ "hyper_parameters.article_content_size" )
35+ self .article_title_size = config .get (
36+ "hyper_parameters.article_title_size" )
2137 self .browse_size = config .get ("hyper_parameters.browse_size" )
22- self .neg_condidate_sample_size = config .get ("hyper_parameters.neg_condidate_sample_size" )
23- self .word_dict_size = int (config .get ("hyper_parameters.word_dict_size" ))
38+ self .neg_condidate_sample_size = config .get (
39+ "hyper_parameters.neg_condidate_sample_size" )
40+ self .word_dict_size = int (
41+ config .get ("hyper_parameters.word_dict_size" ))
2442 self .category_size = int (config .get ("hyper_parameters.category_size" ))
25- self .sub_category_size = int (config .get ("hyper_parameters.sub_category_size" ))
43+ self .sub_category_size = int (
44+ config .get ("hyper_parameters.sub_category_size" ))
2645 self .article_map_cate = {}
2746 self .article_map_title = {}
2847 self .article_map_content = {}
2948 self .article_map_sub_cate = {}
3049 self .init ()
3150
32- def convert_unk (self ,id ):
51+ def convert_unk (self , id ):
3352 if id in self .article_map_cate :
3453 return id
3554 return "padding"
55+
3656 def init (self ):
3757 self .article_map_cate ["padding" ] = self .category_size
3858 self .article_map_sub_cate ["padding" ] = self .sub_category_size
39- self .article_map_title ["padding" ] = [self .word_dict_size ] * self .article_title_size
40- self .article_map_content ["padding" ] = [self .word_dict_size ]* self .article_content_size
59+ self .article_map_title ["padding" ] = [self .word_dict_size
60+ ] * self .article_title_size
61+ self .article_map_content ["padding" ] = [self .word_dict_size
62+ ] * self .article_content_size
4163 #line [0]id cate_id sub_cate_id [3]title content
4264 for file in self .article_file_list :
43- with open (file ,"r" ) as rf :
65+ with open (file , "r" ) as rf :
4466 for l in rf :
4567 line = l .strip ().split ('\t ' )
4668 id = line [0 ]
4769 #line 0 cate 1:subcate, 2:title, 3 content;
48- line = [[int (line [1 ])],[int (line [2 ])],[int (t ) for t in line [3 ].split (" " )],[int (t ) for t in line [4 ].split (" " )]]
49- line [2 ] += [self .word_dict_size ] * (self .article_title_size - len (line [2 ]))
50- line [3 ] += [self .word_dict_size ] * (self .article_content_size - len (line [3 ]))
70+ line = [[int (line [1 ])], [int (line [2 ])],
71+ [int (t ) for t in line [3 ].split (" " )],
72+ [int (t ) for t in line [4 ].split (" " )]]
73+ line [2 ] += [self .word_dict_size ] * (
74+ self .article_title_size - len (line [2 ]))
75+ line [3 ] += [self .word_dict_size ] * (
76+ self .article_content_size - len (line [3 ]))
5177 self .article_map_cate [id ] = line [0 ][0 ]
5278 self .article_map_sub_cate [id ] = line [1 ][0 ]
5379 if len (line [2 ]) > self .article_title_size :
@@ -77,29 +103,61 @@ def __iter__(self):
77103 line [0 ] += ["unk" ] * (self .browse_size - len (line [0 ]))
78104 neg_candidate = line [2 ]
79105 if len (neg_candidate ) < self .neg_condidate_sample_size :
80- continue ;
106+ continue
81107 candidate = neg_candidate [:self .neg_condidate_sample_size ]
82108 candidate .append (line [1 ][0 ])
83109 line [1 ] = []
84110 ids = list (range (self .neg_condidate_sample_size + 1 ))
85111 random .shuffle (ids )
86112 label = []
87113 for i in ids :
88- line [1 ].append (candidate [i ]) #1 condidate 0:visit
114+ line [1 ].append (candidate [i ]) #1 condidate 0:visit
89115 if i == self .neg_condidate_sample_size :
90116 label .append (1 )
91117 else :
92118 label .append (0 )
93119
94120 article_list = [np .array (label )]
95- # l = [self.article_map[i] for i in line[1]]
96- article_list .append (np .array ([self .article_map_cate [self .convert_unk (i )] for i in line [1 ]]))
97- article_list .append (np .array ([self .article_map_cate [self .convert_unk (i )] for i in line [0 ]]))
98- article_list .append (np .array ([self .article_map_sub_cate [self .convert_unk (i )] for i in line [1 ]]))
99- article_list .append (np .array ([self .article_map_sub_cate [self .convert_unk (i )] for i in line [0 ]]))
100- article_list .append (np .array ([self .article_map_title [self .convert_unk (i )] for i in line [1 ]]))
101- article_list .append (np .array ([self .article_map_title [self .convert_unk (i )] for i in line [0 ]]))
102- article_list .append (np .array ([self .article_map_content [self .convert_unk (i )] for i in line [1 ]]))
103- article_list .append (np .array ([self .article_map_content [self .convert_unk (i )] for i in line [0 ]]))
121+ # l = [self.article_map[i] for i in line[1]]
122+ article_list .append (
123+ np .array ([
124+ self .article_map_cate [self .convert_unk (i )]
125+ for i in line [1 ]
126+ ]))
127+ article_list .append (
128+ np .array ([
129+ self .article_map_cate [self .convert_unk (i )]
130+ for i in line [0 ]
131+ ]))
132+ article_list .append (
133+ np .array ([
134+ self .article_map_sub_cate [self .convert_unk (i )]
135+ for i in line [1 ]
136+ ]))
137+ article_list .append (
138+ np .array ([
139+ self .article_map_sub_cate [self .convert_unk (i )]
140+ for i in line [0 ]
141+ ]))
142+ article_list .append (
143+ np .array ([
144+ self .article_map_title [self .convert_unk (i )]
145+ for i in line [1 ]
146+ ]))
147+ article_list .append (
148+ np .array ([
149+ self .article_map_title [self .convert_unk (i )]
150+ for i in line [0 ]
151+ ]))
152+ article_list .append (
153+ np .array ([
154+ self .article_map_content [self .convert_unk (i )]
155+ for i in line [1 ]
156+ ]))
157+ article_list .append (
158+ np .array ([
159+ self .article_map_content [self .convert_unk (i )]
160+ for i in line [0 ]
161+ ]))
104162 #output_list = [article_list,None]
105163 yield article_list
0 commit comments