Skip to content

Commit aff54ef

Browse files
committed
add ctr data
1 parent 40d65a1 commit aff54ef

File tree

1 file changed

+123
-0
lines changed

1 file changed

+123
-0
lines changed
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import print_function
16+
17+
from paddle.fluid import core
18+
from paddle.fluid.executor import global_scope
19+
from paddle.fluid.framework import default_main_program, \
20+
default_startup_program, Variable
21+
from paddle.fluid.unique_name import generate as unique_name
22+
23+
24+
def monkey_patch_reader_methods(reader):
    """
    Attach a ``reset()`` method to ``reader`` and flag it as persistable.

    The patched ``reset()`` looks up the variable called ``reader.name`` in
    the global scope on every call, fetches the reader held by that variable
    and forwards the call to its ``reset()``.

    Args:
        reader: The reader variable to patch. It is modified in place.

    Returns:
        The same ``reader`` object, with ``reset`` attached and both
        ``stop_gradient`` and ``persistable`` set to True.
    """

    def _underlying_reader():
        # Resolved lazily so the scope is only consulted when reset() runs.
        return global_scope().find_var(reader.name).get_reader()

    def reset():
        return _underlying_reader().reset()

    reader.persistable = True
    reader.stop_gradient = True
    reader.reset = reset
    return reader
37+
38+
39+
def _copy_reader_var_(block, var):
    """
    Create a READER variable in ``block`` that mirrors ``var``.

    The new variable reuses ``var``'s name, copies the shapes and dtypes
    stored in ``var.desc`` and is marked persistable.

    Args:
        block: The block in which the copy is created.
        var: The reader variable to copy from.

    Returns:
        The newly created variable.
    """
    copied = block.create_var(
        name=var.name, type=core.VarDesc.VarType.READER)
    copied_desc = copied.desc
    copied_desc.set_shapes(var.desc.shapes())
    copied_desc.set_dtypes(var.desc.dtypes())
    copied.persistable = True
    return copied
45+
46+
47+
def ctr_reader(feed_data,
               capacity,
               thread_num,
               batch_size,
               file_list,
               slots,
               name=None):
    """
    Create a CTR reader and wire it into the default programs.

    A ``create_ctr_reader`` op is appended to the current block of the
    default startup program, and a ``read`` op whose outputs are
    ``feed_data`` is appended to the current block of the default main
    program. A LoDTensor blocking queue is created in the global scope and
    handed to the ``create_ctr_reader`` op as its ``blocking_queue`` input.

    Args:
        feed_data(Variable|list|tuple): Output variable(s) of the ``read``
            op. Their ``shape`` attributes are also used to initialise the
            blocking queue.
        capacity(int): Capacity passed to the blocking queue.
        thread_num: Forwarded unchanged as the ``thread_num`` attribute of
            the ``create_ctr_reader`` op (presumably the number of reading
            threads - confirm against the op definition).
        batch_size: Forwarded unchanged as the ``batch_size`` attribute
            (presumably an int).
        file_list: Forwarded unchanged as the ``file_list`` attribute
            (presumably a list of file paths).
        slots: Forwarded unchanged as the ``slots`` attribute (presumably
            the slot ids to read - confirm against the op definition).
        name(basestring|None): Prefix for the queue and reader names, which
            become ``<name>_queue`` and ``<name>_reader``. When None, unique
            names are generated automatically.

    Returns:
        Variable: The reader variable of the main program. It carries a
        ``reset()`` method, a ``queue`` attribute holding the blocking
        queue, and an ``exited`` attribute initialised to False.
    """
    if name is None:
        queue_name = unique_name('lod_tensor_blocking_queue')
        reader_name = unique_name('create_ctr_reader')
    else:
        queue_name = "_".join([name, "queue"])
        reader_name = "_".join([name, "reader"])

    # The previous code passed an undefined name ``shapes`` to the queue
    # initialiser, which raised NameError on every call.
    # NOTE(review): the shapes are now derived from the feed variables; this
    # assumes each of them exposes ``.shape`` - confirm that this is what the
    # blocking queue expects for the CTR reader.
    if isinstance(feed_data, (list, tuple)):
        feed_vars = feed_data
    else:
        feed_vars = [feed_data]
    shapes = [list(feed_var.shape) for feed_var in feed_vars]

    var = global_scope().var(queue_name)
    feed_queue = core.init_lod_tensor_blocking_queue(var, capacity, shapes)

    startup_blk = default_startup_program().current_block()
    reader_var = startup_blk.create_var(name=reader_name)
    startup_blk.append_op(
        type='create_ctr_reader',
        inputs={'blocking_queue': [queue_name]},
        outputs={'Out': [reader_var]},
        attrs={
            'thread_num': thread_num,
            'batch_size': batch_size,
            'file_list': file_list,
            'slots': slots,
        })

    reader_var.persistable = True

    # The main program needs its own reader variable with the same name.
    main_blk = default_main_program().current_block()
    main_prog_reader_var = _copy_reader_var_(main_blk, reader_var)

    reader = monkey_patch_reader_methods(main_prog_reader_var)

    # Extra attributes patched onto the reader variable.
    reader.queue = feed_queue
    reader.exited = False

    main_blk.append_op(
        type='read', inputs={'Reader': [reader]}, outputs={'Out': feed_data})

    return reader

0 commit comments

Comments
 (0)