Skip to content

Commit 9e83dac

Browse files
author
yuyang18
committed
Add missing file for last commit
also rename hdfs_data to make it internally. ISSUE=4604505 git-svn-id: https://svn.baidu.com/idl/trunk/paddle@1451 1ad973e4-5ce8-4261-8a94-b56d1f490c56
1 parent eef13ff commit 9e83dac

File tree

2 files changed

+180
-18
lines changed

2 files changed

+180
-18
lines changed

demo/image_classification/upload_hadoop.sh

Lines changed: 0 additions & 18 deletions
This file was deleted.
Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import paddle.trainer.PyDataProvider2 as dp2
16+
import collections
17+
import swig_paddle
18+
19+
__all__ = ['DataProviderConverter']
20+
21+
22+
class IScanner(object):
23+
def __init__(self, input_type, pos):
24+
self.input_type = input_type
25+
assert isinstance(self.input_type, dp2.InputType)
26+
self.pos = pos
27+
28+
def scan(self, dat):
29+
pass
30+
31+
def finish_scan(self, argument):
32+
pass
33+
34+
35+
class DenseScanner(IScanner):
36+
def __init__(self, input_type, pos):
37+
IScanner.__init__(self, input_type, pos)
38+
self.__mat__ = []
39+
self.__height__ = 0
40+
41+
def scan(self, dat):
42+
self.__mat__.extend(dat)
43+
self.__height__ += 1
44+
45+
def finish_scan(self, argument):
46+
assert isinstance(argument, swig_paddle.Arguments)
47+
assert isinstance(self.input_type, dp2.InputType)
48+
m = swig_paddle.Matrix.createDense(self.__mat__,
49+
self.__height__,
50+
self.input_type.dim,
51+
False)
52+
argument.setSlotValue(self.pos, m)
53+
54+
55+
class SparseBinaryScanner(IScanner):
56+
def __init__(self, input_type, pos):
57+
IScanner.__init__(self, input_type, pos)
58+
self.__rows__ = [0]
59+
self.__cols__ = []
60+
self.__height__ = 0
61+
self.__nnz__ = 0
62+
self.__value__ = []
63+
64+
def scan(self, dat):
65+
self.extend_cols(dat)
66+
self.__rows__.append(len(dat))
67+
68+
def extend_cols(self, dat):
69+
self.__cols__.extend(dat)
70+
71+
def finish_scan(self, argument):
72+
assert isinstance(argument, swig_paddle.Arguments)
73+
assert isinstance(self.input_type, dp2.InputType)
74+
m = swig_paddle.Matrix.createSparse(self.__height__,
75+
self.input_type.dim,
76+
len(self.__cols__),
77+
len(self.__value__) == 0)
78+
assert isinstance(m, swig_paddle.Matrix)
79+
m.sparseCopyFrom(self.__rows__, self.__cols__, self.__value__)
80+
argument.setSlotValue(self.pos, m)
81+
82+
83+
class SparseFloatScanner(SparseBinaryScanner):
84+
def __init__(self, input_type, pos):
85+
SparseBinaryScanner.__init__(self, input_type, pos)
86+
87+
def extend_cols(self, dat):
88+
self.__cols__.extend((x[0] for x in dat))
89+
self.__value__.extend((x[1] for x in dat))
90+
91+
92+
class IndexScanner(IScanner):
93+
def __init__(self, input_type, pos):
94+
IScanner.__init__(self, input_type, pos)
95+
self.__ids__ = []
96+
97+
def scan(self, dat):
98+
self.__ids__.append(dat)
99+
100+
def finish_scan(self, argument):
101+
ids = swig_paddle.IVector.create(self.__ids__)
102+
assert isinstance(argument, swig_paddle.Arguments)
103+
argument.setSlotIds(self.pos, ids)
104+
105+
106+
class SequenceScanner(IScanner):
107+
def __init__(self, input_type, pos, inner_scanner, setter):
108+
IScanner.__init__(self, input_type, pos)
109+
self.__seq__ = [0]
110+
self.__inner_scanner__ = inner_scanner
111+
self.__setter__ = setter
112+
113+
def scan(self, dat):
114+
self.__seq__.append(self.__seq__[-1] + self.get_size(dat))
115+
for each in dat:
116+
self.__inner_scanner__.scan(each)
117+
118+
def finish_scan(self, argument):
119+
seq = swig_paddle.IVector.create(self.__seq__, False)
120+
self.__setter__(argument, self.pos, seq)
121+
self.__inner_scanner__.finish_scan(argument)
122+
123+
def get_size(self, dat):
124+
if isinstance(self.__inner_scanner__, SequenceScanner):
125+
return sum(self.__inner_scanner__.get_size(item) for item in dat)
126+
else:
127+
return len(dat)
128+
129+
130+
class DataProviderConverter(object):
131+
def __init__(self, input_types):
132+
self.input_types = input_types
133+
assert isinstance(self.input_types, collections.Sequence)
134+
for each in self.input_types:
135+
assert isinstance(each, dp2.InputType)
136+
137+
def convert(self, dat, argument=None):
138+
if argument is None:
139+
argument = swig_paddle.Arguments.createArguments(0)
140+
assert isinstance(argument, swig_paddle.Arguments)
141+
argument.resize(len(self.input_types))
142+
143+
scanners = [DataProviderConverter.create_scanner(i, each_type)
144+
for i, each_type in enumerate(self.input_types)]
145+
146+
for each_sample in dat:
147+
for each_step, scanner in zip(each_sample, scanners):
148+
scanner.scan(each_step)
149+
150+
for scanner in scanners:
151+
scanner.finish_scan(argument)
152+
153+
return argument
154+
155+
def __call__(self, dat, argument=None):
156+
return self.convert(dat, argument)
157+
158+
@staticmethod
159+
def create_scanner(i, each):
160+
assert isinstance(each, dp2.InputType)
161+
retv = None
162+
if each.type == dp2.DataType.Dense:
163+
retv = DenseScanner(each, i)
164+
elif each.type == dp2.DataType.Index:
165+
retv = IndexScanner(each, i)
166+
elif each.type == dp2.DataType.SparseNonValue:
167+
retv = SparseBinaryScanner(each, i)
168+
elif each.type == dp2.DataType.SparseValue:
169+
retv = SparseFloatScanner(each, i)
170+
assert retv is not None
171+
172+
if each.seq_type == dp2.SequenceType.SUB_SEQUENCE:
173+
retv = SequenceScanner(each, i, retv, lambda a, p, seq:
174+
a.setSlotSubSequenceStartPositions(p, seq))
175+
176+
if each.seq_type in [dp2.SequenceType.SUB_SEQUENCE,
177+
dp2.SequenceType.SEQUENCE]:
178+
retv = SequenceScanner(each, i, retv, lambda a, p, seq:
179+
a.setSlotSequenceStartPositions(p, seq))
180+
return retv

0 commit comments

Comments
 (0)