Skip to content

Commit ffd045a

Browse files
committed
Fix data_layers and data_type function in topology.py
1 parent c6bfb71 commit ffd045a

File tree

1 file changed

+15
-6
lines changed

1 file changed

+15
-6
lines changed

python/paddle/v2/topology.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -68,13 +68,18 @@ def data_layers(self):
6868
get all data layer
6969
:return:
7070
"""
71-
data_layers = set()
71+
data_layers = dict()
7272

7373
def find_data_layer(layer):
7474
if isinstance(layer, v2_layer.DataLayerV2):
75-
data_layers.add(layer)
76-
for parent_layer in layer.__parent_layers__.values():
77-
find_data_layer(parent_layer)
75+
data_layers[layer.name] = layer
76+
if not isinstance(layer, collections.Sequence):
77+
for parent_layer in layer.__parent_layers__.values():
78+
find_data_layer(parent_layer)
79+
else:
80+
for each_l in layer:
81+
for parent_layer in each_l.__parent_layers__.values():
82+
find_data_layer(parent_layer)
7883

7984
for layer in self.layers:
8085
find_data_layer(layer)
@@ -86,8 +91,12 @@ def data_type(self):
8691
get data_type from proto, such as:
8792
[('image', dense_vector(768)), ('label', integer_value(10))]
8893
"""
89-
return [(data_layer.name, data_layer.type)
90-
for data_layer in self.data_layers()]
94+
95+
data_types_lists = []
96+
for each in self.__model_config__.input_layer_names:
97+
data_layers = self.data_layers()
98+
data_types_lists.append((each, data_layers[each].type))
99+
return data_types_lists
91100

92101

93102
def __check_layer_type__(layer):

0 commit comments

Comments
 (0)