Skip to content

Commit 5c63418

Browse files
authored
Merge pull request #1497 from reyoung/feature/extract_common_base_method_to_config_base
Add config_base.py for Layer
2 parents cda4579 + 6dd2165 commit 5c63418

File tree

2 files changed

+73
-72
lines changed

2 files changed

+73
-72
lines changed

python/paddle/v2/config_base.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
import collections
2+
3+
from paddle.trainer_config_helpers.default_decorators import wrap_name_default
4+
import paddle.trainer_config_helpers as conf_helps
5+
6+
7+
class Layer(object):
8+
def __init__(self, name=None, parent_layers=None):
9+
assert isinstance(parent_layers, dict)
10+
self.name = name
11+
self.__parent_layers__ = parent_layers
12+
13+
def to_proto(self, context):
14+
"""
15+
function to set proto attribute
16+
"""
17+
kwargs = dict()
18+
for layer_name in self.__parent_layers__:
19+
if not isinstance(self.__parent_layers__[layer_name],
20+
collections.Sequence):
21+
v1_layer = self.__parent_layers__[layer_name].to_proto(
22+
context=context)
23+
else:
24+
v1_layer = map(lambda x: x.to_proto(context=context),
25+
self.__parent_layers__[layer_name])
26+
kwargs[layer_name] = v1_layer
27+
28+
if self.name is None:
29+
return self.to_proto_impl(**kwargs)
30+
elif self.name not in context:
31+
context[self.name] = self.to_proto_impl(**kwargs)
32+
33+
return context[self.name]
34+
35+
def to_proto_impl(self, **kwargs):
36+
raise NotImplementedError()
37+
38+
39+
def __convert_to_v2__(method_name, parent_names, is_default_name=True):
40+
if is_default_name:
41+
wrapper = wrap_name_default(name_prefix=method_name)
42+
else:
43+
wrapper = None
44+
45+
class V2LayerImpl(Layer):
46+
def __init__(self, **kwargs):
47+
parent_layers = dict()
48+
other_kwargs = dict()
49+
for pname in parent_names:
50+
if kwargs.has_key(pname):
51+
parent_layers[pname] = kwargs[pname]
52+
53+
for key in kwargs.keys():
54+
if key not in parent_names:
55+
other_kwargs[key] = kwargs[key]
56+
57+
name = kwargs.get('name', None)
58+
super(V2LayerImpl, self).__init__(name, parent_layers)
59+
self.__other_kwargs__ = other_kwargs
60+
61+
if wrapper is not None:
62+
__init__ = wrapper(__init__)
63+
64+
def to_proto_impl(self, **kwargs):
65+
args = dict()
66+
for each in kwargs:
67+
args[each] = kwargs[each]
68+
for each in self.__other_kwargs__:
69+
args[each] = self.__other_kwargs__[each]
70+
return getattr(conf_helps, method_name)(**args)
71+
72+
return V2LayerImpl

python/paddle/v2/layer.py

Lines changed: 1 addition & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,7 @@ class for each layer creation function in paddle.trainer_config_helpers.layers.
6565
Also, the creation of a protobuf message is hidden in the invocation of
6666
paddle.v2.parameters.create, no longer exposed to users.
6767
"""
68-
69-
import collections
70-
import inspect
71-
68+
from config_base import Layer, __convert_to_v2__
7269
import paddle.trainer_config_helpers as conf_helps
7370
from paddle.trainer_config_helpers.config_parser_utils import \
7471
parse_network_config as __parse__
@@ -107,74 +104,6 @@ def __real_func__():
107104
return __parse__(__real_func__)
108105

109106

110-
class Layer(object):
111-
def __init__(self, name=None, parent_layers=None):
112-
assert isinstance(parent_layers, dict)
113-
self.name = name
114-
self.__parent_layers__ = parent_layers
115-
116-
def to_proto(self, context):
117-
"""
118-
function to set proto attribute
119-
"""
120-
kwargs = dict()
121-
for layer_name in self.__parent_layers__:
122-
if not isinstance(self.__parent_layers__[layer_name],
123-
collections.Sequence):
124-
v1_layer = self.__parent_layers__[layer_name].to_proto(
125-
context=context)
126-
else:
127-
v1_layer = map(lambda x: x.to_proto(context=context),
128-
self.__parent_layers__[layer_name])
129-
kwargs[layer_name] = v1_layer
130-
131-
if self.name is None:
132-
return self.to_proto_impl(**kwargs)
133-
elif self.name not in context:
134-
context[self.name] = self.to_proto_impl(**kwargs)
135-
136-
return context[self.name]
137-
138-
def to_proto_impl(self, **kwargs):
139-
raise NotImplementedError()
140-
141-
142-
def __convert_to_v2__(method_name, parent_names, is_default_name=True):
143-
if is_default_name:
144-
wrapper = wrap_name_default(name_prefix=method_name)
145-
else:
146-
wrapper = None
147-
148-
class V2LayerImpl(Layer):
149-
def __init__(self, **kwargs):
150-
parent_layers = dict()
151-
other_kwargs = dict()
152-
for pname in parent_names:
153-
if kwargs.has_key(pname):
154-
parent_layers[pname] = kwargs[pname]
155-
156-
for key in kwargs.keys():
157-
if key not in parent_names:
158-
other_kwargs[key] = kwargs[key]
159-
160-
name = kwargs.get('name', None)
161-
super(V2LayerImpl, self).__init__(name, parent_layers)
162-
self.__other_kwargs__ = other_kwargs
163-
164-
if wrapper is not None:
165-
__init__ = wrapper(__init__)
166-
167-
def to_proto_impl(self, **kwargs):
168-
args = dict()
169-
for each in kwargs:
170-
args[each] = kwargs[each]
171-
for each in self.__other_kwargs__:
172-
args[each] = self.__other_kwargs__[each]
173-
return getattr(conf_helps, method_name)(**args)
174-
175-
return V2LayerImpl
176-
177-
178107
"""
179108
Some layer may need some special config, and can not use __convert_to_v2__ to convert.
180109
So we also need to implement some special LayerV2.

0 commit comments

Comments
 (0)