Skip to content

Commit 4f06cd1

Browse files
Pick revert data generator (#32700)
* revert data_generator * add setup.py
1 parent f3436af commit 4f06cd1

File tree

2 files changed

+344
-0
lines changed

2 files changed

+344
-0
lines changed
Lines changed: 343 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,343 @@
1+
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
import sys
17+
18+
__all__ = ['MultiSlotDataGenerator', 'MultiSlotStringDataGenerator']
19+
20+
21+
class DataGenerator(object):
22+
"""
23+
DataGenerator is a general Base class for user to inherit
24+
A user who wants to define his/her own python processing logic
25+
with paddle.fluid.dataset should inherit this class.
26+
"""
27+
28+
def __init__(self):
29+
self._proto_info = None
30+
self.batch_size_ = 32
31+
32+
def _set_line_limit(self, line_limit):
33+
if not isinstance(line_limit, int):
34+
raise ValueError("line_limit%s must be in int type" %
35+
type(line_limit))
36+
if line_limit < 1:
37+
raise ValueError("line_limit can not less than 1")
38+
self._line_limit = line_limit
39+
40+
def set_batch(self, batch_size):
41+
'''
42+
Set batch size of current DataGenerator
43+
This is necessary only if a user wants to define generator_batch
44+
45+
Example:
46+
.. code-block:: python
47+
import paddle.fluid.incubate.data_generator as dg
48+
class MyData(dg.DataGenerator):
49+
def generate_sample(self, line):
50+
def local_iter():
51+
int_words = [int(x) for x in line.split()]
52+
yield ("words", int_words)
53+
return local_iter
54+
def generate_batch(self, samples):
55+
def local_iter():
56+
for s in samples:
57+
yield ("words", s[1].extend([s[1][0]]))
58+
mydata = MyData()
59+
mydata.set_batch(128)
60+
61+
'''
62+
self.batch_size_ = batch_size
63+
64+
def run_from_memory(self):
65+
'''
66+
This function generator data from memory, it is usually used for
67+
debug and benchmarking
68+
Example:
69+
.. code-block:: python
70+
import paddle.fluid.incubate.data_generator as dg
71+
class MyData(dg.DataGenerator):
72+
def generate_sample(self, line):
73+
def local_iter():
74+
yield ("words", [1, 2, 3, 4])
75+
return local_iter
76+
mydata = MyData()
77+
mydata.run_from_memory()
78+
'''
79+
batch_samples = []
80+
line_iter = self.generate_sample(None)
81+
for user_parsed_line in line_iter():
82+
if user_parsed_line == None:
83+
continue
84+
batch_samples.append(user_parsed_line)
85+
if len(batch_samples) == self.batch_size_:
86+
batch_iter = self.generate_batch(batch_samples)
87+
for sample in batch_iter():
88+
sys.stdout.write(self._gen_str(sample))
89+
batch_samples = []
90+
if len(batch_samples) > 0:
91+
batch_iter = self.generate_batch(batch_samples)
92+
for sample in batch_iter():
93+
sys.stdout.write(self._gen_str(sample))
94+
95+
def run_from_stdin(self):
96+
'''
97+
This function reads the data row from stdin, parses it with the
98+
process function, and further parses the return value of the
99+
process function with the _gen_str function. The parsed data will
100+
be wrote to stdout and the corresponding protofile will be
101+
generated.
102+
Example:
103+
104+
.. code-block:: python
105+
import paddle.fluid.incubate.data_generator as dg
106+
class MyData(dg.DataGenerator):
107+
def generate_sample(self, line):
108+
def local_iter():
109+
int_words = [int(x) for x in line.split()]
110+
yield ("words", [int_words])
111+
return local_iter
112+
mydata = MyData()
113+
mydata.run_from_stdin()
114+
'''
115+
batch_samples = []
116+
for line in sys.stdin:
117+
line_iter = self.generate_sample(line)
118+
for user_parsed_line in line_iter():
119+
if user_parsed_line == None:
120+
continue
121+
batch_samples.append(user_parsed_line)
122+
if len(batch_samples) == self.batch_size_:
123+
batch_iter = self.generate_batch(batch_samples)
124+
for sample in batch_iter():
125+
sys.stdout.write(self._gen_str(sample))
126+
batch_samples = []
127+
if len(batch_samples) > 0:
128+
batch_iter = self.generate_batch(batch_samples)
129+
for sample in batch_iter():
130+
sys.stdout.write(self._gen_str(sample))
131+
132+
def _gen_str(self, line):
133+
'''
134+
Further processing the output of the process() function rewritten by
135+
user, outputting data that can be directly read by the datafeed,and
136+
updating proto_info information.
137+
Args:
138+
line(str): the output of the process() function rewritten by user.
139+
Returns:
140+
Return a string data that can be read directly by the datafeed.
141+
'''
142+
raise NotImplementedError(
143+
"pls use MultiSlotDataGenerator or PairWiseDataGenerator")
144+
145+
def generate_sample(self, line):
146+
'''
147+
This function needs to be overridden by the user to process the
148+
original data row into a list or tuple.
149+
Args:
150+
line(str): the original data row
151+
Returns:
152+
Returns the data processed by the user.
153+
The data format is list or tuple:
154+
[(name, [feasign, ...]), ...]
155+
or ((name, [feasign, ...]), ...)
156+
157+
For example:
158+
[("words", [1926, 08, 17]), ("label", [1])]
159+
or (("words", [1926, 08, 17]), ("label", [1]))
160+
Note:
161+
The type of feasigns must be in int or float. Once the float
162+
element appears in the feasign, the type of that slot will be
163+
processed into a float.
164+
Example:
165+
.. code-block:: python
166+
import paddle.fluid.incubate.data_generator as dg
167+
class MyData(dg.DataGenerator):
168+
def generate_sample(self, line):
169+
def local_iter():
170+
int_words = [int(x) for x in line.split()]
171+
yield ("words", [int_words])
172+
return local_iter
173+
'''
174+
raise NotImplementedError(
175+
"Please rewrite this function to return a list or tuple: " +
176+
"[(name, [feasign, ...]), ...] or ((name, [feasign, ...]), ...)")
177+
178+
def generate_batch(self, samples):
179+
'''
180+
This function needs to be overridden by the user to process the
181+
generated samples from generate_sample(self, str) function
182+
It is usually used as batch processing when a user wants to
183+
do preprocessing on a batch of samples, e.g. padding according to
184+
the max length of a sample in the batch
185+
Args:
186+
samples(list tuple): generated sample from generate_sample
187+
Returns:
188+
a python generator, the same format as return value of generate_sample
189+
Example:
190+
.. code-block:: python
191+
import paddle.fluid.incubate.data_generator as dg
192+
class MyData(dg.DataGenerator):
193+
def generate_sample(self, line):
194+
def local_iter():
195+
int_words = [int(x) for x in line.split()]
196+
yield ("words", int_words)
197+
return local_iter
198+
def generate_batch(self, samples):
199+
def local_iter():
200+
for s in samples:
201+
yield ("words", s[1].extend([s[1][0]]))
202+
mydata = MyData()
203+
mydata.set_batch(128)
204+
'''
205+
206+
def local_iter():
207+
for sample in samples:
208+
yield sample
209+
210+
return local_iter
211+
212+
213+
# TODO: guru4elephant
214+
# add more generalized DataGenerator that can adapt user-defined slot
215+
# for example, [(name, float_list), (name, str_list), (name, int_list)]
216+
class MultiSlotStringDataGenerator(DataGenerator):
217+
def _gen_str(self, line):
218+
'''
219+
Further processing the output of the process() function rewritten by
220+
user, outputting data that can be directly read by the MultiSlotDataFeed,
221+
and updating proto_info information.
222+
The input line will be in this format:
223+
>>> [(name, [str(feasign), ...]), ...]
224+
>>> or ((name, [str(feasign), ...]), ...)
225+
The output will be in this format:
226+
>>> [ids_num id1 id2 ...] ...
227+
For example, if the input is like this:
228+
>>> [("words", ["1926", "08", "17"]), ("label", ["1"])]
229+
>>> or (("words", ["1926", "08", "17"]), ("label", ["1"]))
230+
the output will be:
231+
>>> 3 1234 2345 3456 1 1
232+
Args:
233+
line(str): the output of the process() function rewritten by user.
234+
Returns:
235+
Return a string data that can be read directly by the MultiSlotDataFeed.
236+
'''
237+
if not isinstance(line, list) and not isinstance(line, tuple):
238+
raise ValueError(
239+
"the output of process() must be in list or tuple type"
240+
"Examples: [('words', ['1926', '08', '17']), ('label', ['1'])]")
241+
output = ""
242+
for index, item in enumerate(line):
243+
name, elements = item
244+
if output:
245+
output += " "
246+
out_str = []
247+
out_str.append(str(len(elements)))
248+
out_str.extend(elements)
249+
output += " ".join(out_str)
250+
return output + "\n"
251+
252+
253+
class MultiSlotDataGenerator(DataGenerator):
254+
def _gen_str(self, line):
255+
'''
256+
Further processing the output of the process() function rewritten by
257+
user, outputting data that can be directly read by the MultiSlotDataFeed,
258+
and updating proto_info information.
259+
The input line will be in this format:
260+
>>> [(name, [feasign, ...]), ...]
261+
>>> or ((name, [feasign, ...]), ...)
262+
The output will be in this format:
263+
>>> [ids_num id1 id2 ...] ...
264+
The proto_info will be in this format:
265+
>>> [(name, type), ...]
266+
267+
For example, if the input is like this:
268+
>>> [("words", [1926, 08, 17]), ("label", [1])]
269+
>>> or (("words", [1926, 08, 17]), ("label", [1]))
270+
the output will be:
271+
>>> 3 1234 2345 3456 1 1
272+
the proto_info will be:
273+
>>> [("words", "uint64"), ("label", "uint64")]
274+
Args:
275+
line(str): the output of the process() function rewritten by user.
276+
Returns:
277+
Return a string data that can be read directly by the MultiSlotDataFeed.
278+
'''
279+
if not isinstance(line, list) and not isinstance(line, tuple):
280+
raise ValueError(
281+
"the output of process() must be in list or tuple type"
282+
"Example: [('words', [1926, 08, 17]), ('label', [1])]")
283+
output = ""
284+
285+
if self._proto_info is None:
286+
self._proto_info = []
287+
for item in line:
288+
name, elements = item
289+
if not isinstance(name, str):
290+
raise ValueError("name%s must be in str type" % type(name))
291+
if not isinstance(elements, list):
292+
raise ValueError("elements%s must be in list type" %
293+
type(elements))
294+
if not elements:
295+
raise ValueError(
296+
"the elements of each field can not be empty, you need padding it in process()."
297+
)
298+
self._proto_info.append((name, "uint64"))
299+
if output:
300+
output += " "
301+
output += str(len(elements))
302+
for elem in elements:
303+
if isinstance(elem, float):
304+
self._proto_info[-1] = (name, "float")
305+
elif not isinstance(elem, int) and not isinstance(elem,
306+
long):
307+
raise ValueError(
308+
"the type of element%s must be in int or float" %
309+
type(elem))
310+
output += " " + str(elem)
311+
else:
312+
if len(line) != len(self._proto_info):
313+
raise ValueError(
314+
"the complete field set of two given line are inconsistent.")
315+
for index, item in enumerate(line):
316+
name, elements = item
317+
if not isinstance(name, str):
318+
raise ValueError("name%s must be in str type" % type(name))
319+
if not isinstance(elements, list):
320+
raise ValueError("elements%s must be in list type" %
321+
type(elements))
322+
if not elements:
323+
raise ValueError(
324+
"the elements of each field can not be empty, you need padding it in process()."
325+
)
326+
if name != self._proto_info[index][0]:
327+
raise ValueError(
328+
"the field name of two given line are not match: require<%s>, get<%s>."
329+
% (self._proto_info[index][0], name))
330+
if output:
331+
output += " "
332+
output += str(len(elements))
333+
for elem in elements:
334+
if self._proto_info[index][1] != "float":
335+
if isinstance(elem, float):
336+
self._proto_info[index] = (name, "float")
337+
elif not isinstance(elem, int) and not isinstance(elem,
338+
long):
339+
raise ValueError(
340+
"the type of element%s must be in int or float"
341+
% type(elem))
342+
output += " " + str(elem)
343+
return output + "\n"

python/setup.py.in

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,7 @@ packages=['paddle',
188188
'paddle.fluid.transpiler',
189189
'paddle.fluid.transpiler.details',
190190
'paddle.fluid.incubate',
191+
'paddle.fluid.incubate.data_generator',
191192
'paddle.fluid.incubate.fleet',
192193
'paddle.fluid.incubate.checkpoint',
193194
'paddle.fluid.incubate.fleet.base',

0 commit comments

Comments
 (0)